<!DOCTYPE HTML>
<html lang="zh-CN">


<head>
    <meta charset="utf-8">
    <meta name="keywords" content="scikit-learn系列三：线性回归, 欢迎来到，TWOTO 的博客">
    <meta name="description" content="技术、效率、摄影">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <!-- Pinch-zoom is left enabled: user-scalable=no prevents low-vision users from magnifying the page. -->
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <meta name="renderer" content="webkit|ie-stand|ie-comp">
    <meta name="mobile-web-app-capable" content="yes">
    <meta name="format-detection" content="telephone=no">
    <meta name="apple-mobile-web-app-capable" content="yes">
    <meta name="apple-mobile-web-app-status-bar-style" content="black-translucent">
    <!-- Global site tag (gtag.js) - Google Analytics -->


    <title>scikit-learn系列三：线性回归 | 欢迎来到，TWOTO 的博客</title>
    <link rel="icon" type="image/png" href="/twoto.png">

    <link rel="stylesheet" type="text/css" href="https://cdn.jsdelivr.net/gh/DongZhouGu/DongZhouGu.github.io/libs/awesome/css/all.css">
    <link rel="stylesheet" type="text/css" href="https://cdn.jsdelivr.net/gh/DongZhouGu/DongZhouGu.github.io/libs/materialize/materialize.min.css">
    <link rel="stylesheet" type="text/css" href="https://cdn.jsdelivr.net/gh/DongZhouGu/DongZhouGu.github.io/libs/aos/aos.css">
    <link rel="stylesheet" type="text/css" href="https://cdn.jsdelivr.net/gh/DongZhouGu/DongZhouGu.github.io/libs/animate/animate.min.css">
    <link rel="stylesheet" type="text/css" href="https://cdn.jsdelivr.net/gh/DongZhouGu/DongZhouGu.github.io/libs/lightGallery/css/lightgallery.min.css">
    <link rel="stylesheet" type="text/css" href="/css/matery.css">
    <link rel="stylesheet" type="text/css" href="/css/my.css">

    <script src="https://cdn.jsdelivr.net/gh/DongZhouGu/DongZhouGu.github.io/libs/jquery/jquery.min.js"></script>

<style type="text/css" lang="css">
    #loading-container{
        position: fixed;
        top: 0;
        left: 0;
        min-height: 100vh;
        width: 100vw;
        z-index: 9999;
        display: flex;
        flex-direction: column;
        justify-content: center;
        align-items: center;
        background: #FFF;
        text-align: center;
        /* loader页面消失采用渐隐的方式*/
        -webkit-transition: opacity 1s ease;
        -moz-transition: opacity 1s ease;
        -o-transition: opacity 1s ease;
        transition: opacity 1s ease;
    }
    .loading-image{
        width: 120px;
        height: 50px;
        transform: translate(-50%);
    }
    
    .loading-image div:nth-child(2) {
        -webkit-animation: pacman-balls 1s linear 0s infinite;
        animation: pacman-balls 1s linear 0s infinite
    }

    .loading-image div:nth-child(3) {
        -webkit-animation: pacman-balls 1s linear .33s infinite;
        animation: pacman-balls 1s linear .33s infinite
    }

    .loading-image div:nth-child(4) {
        -webkit-animation: pacman-balls 1s linear .66s infinite;
        animation: pacman-balls 1s linear .66s infinite
    }

    .loading-image div:nth-child(5) {
        -webkit-animation: pacman-balls 1s linear .99s infinite;
        animation: pacman-balls 1s linear .99s infinite
    }
    
   .loading-image div:first-of-type {
        width: 0;
        height: 0;
        border: 25px solid #49b1f5;
        border-right-color: transparent;
        border-radius: 25px;
        -webkit-animation: rotate_pacman_half_up .5s 0s infinite;
        animation: rotate_pacman_half_up .5s 0s infinite;
    }
    .loading-image div:nth-child(2) {
        width: 0;
        height: 0;
        border: 25px solid #49b1f5;
        border-right-color: transparent;
        border-radius: 25px;
        -webkit-animation: rotate_pacman_half_down .5s 0s infinite;
        animation: rotate_pacman_half_down .5s 0s infinite;
        margin-top: -50px;
    }
    @-webkit-keyframes rotate_pacman_half_up {0% {transform: rotate(270deg)}50% {transform: rotate(1turn)}to {transform: rotate(270deg)}}

    @keyframes rotate_pacman_half_up {0% {transform: rotate(270deg)}50% {transform: rotate(1turn)}to {transform: rotate(270deg)}}

    @-webkit-keyframes rotate_pacman_half_down {0% {transform: rotate(90deg)}50% {transform: rotate(0deg)}to {transform: rotate(90deg)}}

    @keyframes rotate_pacman_half_down {0% {transform: rotate(90deg)}50% {transform: rotate(0deg)}to {transform: rotate(90deg)}}
    
    @-webkit-keyframes pacman-balls {75% {opacity: .7}to {transform: translate(-100px, -6.25px)}}

    @keyframes pacman-balls {75% {opacity: .7}to {transform: translate(-100px, -6.25px)}}
    
   
    .loading-image div:nth-child(3),
    .loading-image div:nth-child(4),
    .loading-image div:nth-child(5),
    .loading-image div:nth-child(6){
        background-color: #49b1f5;
        width: 15px;
        height: 15px;
        border-radius: 100%;
        margin: 2px;
        width: 10px;
        height: 10px;
        position: absolute;
        transform: translateY(-6.25px);
        top: 25px;
        left: 100px;
    }
    .loading-text{
        margin-bottom: 20vh;
        text-align: center;
        color: #2c3e50;
        font-size: 2rem;
        box-sizing: border-box;
        padding: 0 10px;
        text-shadow: 0 2px 10px rgba(0,0,0,0.2);
    }
    @media only screen and (max-width: 500px) {
         .loading-text{
            font-size: 1.5rem;
         }
    }
    .fadeout {
        opacity: 0;
        filter: alpha(opacity=0);
    }
    /* Logo entrance animation: slide down from above while fading in.
       The unprefixed rule mirrors the -webkit- one so both define the same start and end frames. */
    @-webkit-keyframes fadeInDown{0%{opacity:0;-webkit-transform:translate3d(0,-100%,0);transform:translate3d(0,-100%,0)}100%{opacity:1;-webkit-transform:none;transform:none}}
    @keyframes fadeInDown{0%{opacity:0;-webkit-transform:translate3d(0,-100%,0);transform:translate3d(0,-100%,0)}100%{opacity:1;-webkit-transform:none;transform:none}}
 </style>
 <script>
// Fades out the full-page loading overlay (#loading-container) and then removes it from layout.
(function () {
    // Minimum time the overlay stays visible before the fade starts.
    const MIN_VISIBLE_MS = 1000;
    // Delay before display:none is applied; matches the 1s opacity transition on #loading-container.
    const FADE_MS = 1000;

    const loaded = function () {
        setTimeout(function () {
            const loader = document.getElementById("loading-container");
            // This script runs from <head>, before the overlay markup is parsed.
            // If the element is not in the DOM when the timer fires, do nothing instead of throwing.
            if (!loader) {
                return;
            }
            // Add the class rather than overwriting className so any other classes survive.
            loader.classList.add("fadeout");
            setTimeout(function () {
                loader.style.display = "none";
            }, FADE_MS);
        }, MIN_VISIBLE_MS);
    };

    loaded();
})();
 </script><meta name="generator" content="Hexo 4.2.1"><link rel="alternate" href="/atom.xml" title="欢迎来到，TWOTO 的博客" type="application/atom+xml">
<link rel="stylesheet" href="/css/prism-tomorrow.css" type="text/css">
<link rel="stylesheet" href="/css/prism-line-numbers.css" type="text/css">
<!-- The loading-overlay style and script blocks that were repeated here were exact copies of the
     pair defined earlier in this head; the repeat scheduled the fade-out timers a second time. -->
</head>



 <body>
 <!-- Full-page loading overlay. It must live inside <body>: flow content between </head> and <body> is invalid HTML.
      The head script adds the "fadeout" class to this element and later sets display:none on it. -->
 <div id="loading-container">
     <p class="loading-text">玩命加载中 . . . </p>
     <!-- Purely decorative animation: the first two child divs are styled as the two Pac-Man halves,
          the remaining ones as the moving dots. Hidden from assistive technology. -->
     <div class="loading-image" aria-hidden="true">
         <div></div>
         <div></div>
         <div></div>
         <div></div>
         <div></div>
     </div>
 </div>
<!-- Animated line-mesh background (canvas-nest.js), configured through the custom attributes on this tag.
     color is an "R,G,B" triple separated by commas.
     NOTE(review): the library appears to find its own tag as the last script in the document at
     execution time; confirm before adding async/defer. -->
<script color="122,103,238" opacity="0.5" zIndex="-1" count="200"
        src="https://cdn.bootcss.com/canvas-nest.js/1.0.0/canvas-nest.min.js">
</script>

<header class="navbar-fixed">
    <nav id="headNav" class="bg-color nav-transparent">
        <div id="navContainer" class="nav-wrapper container">
            <div class="brand-logo">
                <a href="/" class="waves-effect waves-light">
                    
                    <img src="/medias/mylogo.png" class="logo-img" alt="LOGO">
                    
                    <span class="logo-span">欢迎来到，TWOTO 的博客</span>
                </a>
            </div>
            

<a href="#" data-target="mobile-nav" class="sidenav-trigger button-collapse"><i class="fas fa-bars"></i></a>
<ul class="right nav-menu">
  
  <li class="hide-on-med-and-down nav-item">
    
    <a href="/" class="waves-effect waves-light">
      
      <i class="fas fa-home" style="zoom: 0.6;"></i>
      
      <span>首页</span>
    </a>
    
  </li>
  
  <li class="hide-on-med-and-down nav-item">
    
    <a href="/tags" class="waves-effect waves-light">
      
      <i class="fas fa-tags" style="zoom: 0.6;"></i>
      
      <span>标签</span>
    </a>
    
  </li>
  
  <li class="hide-on-med-and-down nav-item">
    
    <a href="/categories" class="waves-effect waves-light">
      
      <i class="fas fa-bookmark" style="zoom: 0.6;"></i>
      
      <span>分类</span>
    </a>
    
  </li>
  
  <li class="hide-on-med-and-down nav-item">
    
    <a href="/archives" class="waves-effect waves-light">
      
      <i class="fas fa-archive" style="zoom: 0.6;"></i>
      
      <span>归档</span>
    </a>
    
  </li>
  
  <li class="hide-on-med-and-down nav-item">
    
    <a href="/about" class="waves-effect waves-light">
      
      <i class="fas fa-user-circle" style="zoom: 0.6;"></i>
      
      <span>关于</span>
    </a>
    
  </li>
  
  <li class="hide-on-med-and-down nav-item">
    
    <a href="/contact" class="waves-effect waves-light">
      
      <i class="fas fa-comments" style="zoom: 0.6;"></i>
      
      <span>留言板</span>
    </a>
    
  </li>
  
  <li class="hide-on-med-and-down nav-item">
    
    <a href="/friends" class="waves-effect waves-light">
      
      <i class="fas fa-address-book" style="zoom: 0.6;"></i>
      
      <span>友情链接</span>
    </a>
    
  </li>
  
  <li>
    <a href="#searchModal" class="modal-trigger waves-effect waves-light">
      <i id="searchIcon" class="fas fa-search" title="搜索" style="zoom: 0.85;"></i>
    </a>
  </li>
</ul>


<div id="mobile-nav" class="side-nav sidenav">

    <!-- Header of the mobile side navigation: logo, site title and tagline. -->
    <div class="mobile-head bg-color">
        <!-- Decorative logo: the site name follows as text, so the alt text is intentionally empty. -->
        <img src="/medias/mylogo.png" class="logo-img circle responsive-img" alt="">
        <div class="logo-name">欢迎来到，TWOTO 的博客</div>
        <div class="logo-desc">
            技术、效率、摄影
        </div>
    </div>

    

    <ul class="menu-list mobile-menu-list">
        
        <li class="m-nav-item">
	  
		<a href="/" class="waves-effect waves-light">
			
			    <i class="fa-fw fas fa-home"></i>
			
			首页
		</a>
          
        </li>
        
        <li class="m-nav-item">
	  
		<a href="/tags" class="waves-effect waves-light">
			
			    <i class="fa-fw fas fa-tags"></i>
			
			标签
		</a>
          
        </li>
        
        <li class="m-nav-item">
	  
		<a href="/categories" class="waves-effect waves-light">
			
			    <i class="fa-fw fas fa-bookmark"></i>
			
			分类
		</a>
          
        </li>
        
        <li class="m-nav-item">
	  
		<a href="/archives" class="waves-effect waves-light">
			
			    <i class="fa-fw fas fa-archive"></i>
			
			归档
		</a>
          
        </li>
        
        <li class="m-nav-item">
	  
		<a href="/about" class="waves-effect waves-light">
			
			    <i class="fa-fw fas fa-user-circle"></i>
			
			关于
		</a>
          
        </li>
        
        <li class="m-nav-item">
	  
		<a href="/contact" class="waves-effect waves-light">
			
			    <i class="fa-fw fas fa-comments"></i>
			
			留言板
		</a>
          
        </li>
        
        <li class="m-nav-item">
	  
		<a href="/friends" class="waves-effect waves-light">
			
			    <i class="fa-fw fas fa-address-book"></i>
			
			友情链接
		</a>
          
        </li>
        
        
        <li><div class="divider"></div></li>
        <li>
            <!-- Opens in a new tab; rel="noopener" keeps the opened page from reaching back through window.opener. -->
            <a href="https://github.com/DongZhouGu" class="waves-effect waves-light" target="_blank" rel="noopener">
                <i class="fab fa-github-square fa-fw"></i>Fork Me
            </a>
        </li>
        
    </ul>
</div>


        </div>

        
            <style>
    .nav-transparent .github-corner {
        display: none !important;
    }

    .github-corner {
        position: absolute;
        z-index: 10;
        top: 0;
        right: 0;
        border: 0;
        transform: scale(1.1);
    }

    .github-corner svg {
        color: #0f9d58;
        fill: #fff;
        height: 64px;
        width: 64px;
    }

    .github-corner:hover .octo-arm {
        animation: a 0.56s ease-in-out;
    }

    .github-corner .octo-arm {
        animation: none;
    }

    @keyframes a {
        0%,
        to {
            transform: rotate(0);
        }
        20%,
        60% {
            transform: rotate(-25deg);
        }
        40%,
        80% {
            transform: rotate(10deg);
        }
    }
</style>

<!-- GitHub corner ribbon. The SVG is aria-hidden, so aria-label gives the link its accessible name.
     Opens in a new tab; rel="noopener" keeps the opened page from reaching back through window.opener. -->
<a href="https://github.com/DongZhouGu" class="github-corner tooltipped hide-on-med-and-down" target="_blank" rel="noopener"
   data-tooltip="Fork Me" data-position="left" data-delay="50" aria-label="Fork Me">
    <svg viewBox="0 0 250 250" aria-hidden="true">
        <path d="M0,0 L115,115 L130,115 L142,142 L250,250 L250,0 Z"></path>
        <path d="M128.3,109.0 C113.8,99.7 119.0,89.6 119.0,89.6 C122.0,82.7 120.5,78.6 120.5,78.6 C119.2,72.0 123.4,76.3 123.4,76.3 C127.3,80.9 125.5,87.3 125.5,87.3 C122.9,97.6 130.6,101.9 134.4,103.2"
              fill="currentColor" style="transform-origin: 130px 106px;" class="octo-arm"></path>
        <path d="M115.0,115.0 C114.9,115.1 118.7,116.5 119.8,115.4 L133.7,101.6 C136.9,99.2 139.9,98.4 142.2,98.6 C133.8,88.0 127.5,74.4 143.8,58.0 C148.5,53.4 154.0,51.2 159.7,51.0 C160.3,49.4 163.2,43.6 171.4,40.1 C171.4,40.1 176.1,42.5 178.8,56.2 C183.1,58.6 187.2,61.8 190.9,65.4 C194.5,69.0 197.7,73.2 200.1,77.6 C213.8,80.2 216.3,84.9 216.3,84.9 C212.7,93.1 206.9,96.0 205.4,96.6 C205.1,102.4 203.0,107.8 198.3,112.5 C181.9,128.9 168.3,122.5 157.7,114.1 C157.9,116.9 156.7,120.9 152.7,124.9 L141.0,136.5 C139.8,137.7 141.6,141.9 141.8,141.8 Z"
              fill="currentColor" class="octo-body"></path>
    </svg>
</a>
        
    </nav>

</header>





<div class="bg-cover pd-header post-cover" style="background-image: url('https://cdn.jsdelivr.net/gh/DongZhouGu/DongZhouGu.github.io/medias/featureimages/13.jpg')">
    <div class="container" style="right: 0px;left: 0px;">
        <div class="row">
            <div class="col s12 m12 l12">
                <div class="brand">
                    <h1 class="description center-align post-title">scikit-learn系列三：线性回归</h1>
                </div>
            </div>
        </div>
    </div>
</div>




<main class="post-container content">

    
    <link rel="stylesheet" href="https://cdn.jsdelivr.net/gh/DongZhouGu/DongZhouGu.github.io/libs/tocbot/tocbot.css">
<style>
    #articleContent h1::before,
    #articleContent h2::before,
    #articleContent h3::before,
    #articleContent h4::before,
    #articleContent h5::before,
    #articleContent h6::before {
        display: block;
        content: " ";
        height: 100px;
        margin-top: -100px;
        visibility: hidden;
    }

    /* Suppress the focus ring only for pointer-initiated focus; keyboard navigation keeps a visible outline. */
    #articleContent :focus:not(:focus-visible) {
        outline: none;
    }

    .toc-fixed {
        position: fixed;
        top: 64px;
    }

    .toc-widget {
        width: 345px;
        padding-left: 20px;
    }

    .toc-widget .toc-title {
        margin: 35px 0 15px 0;
        padding-left: 17px;
        font-size: 1.5rem;
        font-weight: bold;
        line-height: 1.5rem;
    }

    .toc-widget ol {
        padding: 0;
        list-style: none;
    }

    #toc-content {
        height: calc(100vh - 250px);
        overflow: auto;
    }

    #toc-content ol {
        padding-left: 10px;
    }

    #toc-content ol li {
        padding-left: 10px;
    }

    #toc-content .toc-link:hover {
        color: #42b983;
        font-weight: 700;
        text-decoration: underline;
    }

    #toc-content .toc-link::before {
        background-color: transparent;
        max-height: 25px;

        position: absolute;
        right: 23.5vw;
        display: block;
    }

    #toc-content .is-active-link {
        color: #42b983;
    }

    #floating-toc-btn {
        position: fixed;
        right: 15px;
        bottom: 76px;
        padding-top: 15px;
        margin-bottom: 0;
        z-index: 998;
    }

    #floating-toc-btn .btn-floating {
        width: 48px;
        height: 48px;
    }

    #floating-toc-btn .btn-floating i {
        line-height: 48px;
        font-size: 1.4rem;
    }
</style>
<div class="row">
    <div id="main-content" class="col s12 m12 l9">
        <!-- 文章内容详情 -->
<div id="artDetail">
    <div class="card">
        <div class="card-content article-info">
            <div class="row tag-cate">
                <div class="col s7">
                    
                    <div class="article-tag">
                        
                            <a href="/tags/%E5%9F%BA%E7%A1%80%E7%9F%A5%E8%AF%86/">
                                <span class="chip bg-color">基础知识</span>
                            </a>
                        
                            <a href="/tags/ML%E7%AE%97%E6%B3%95/">
                                <span class="chip bg-color">ML算法</span>
                            </a>
                        
                    </div>
                    
                </div>
                <div class="col s5 right-align">
                    
                    <div class="post-cate">
                        <i class="fas fa-bookmark fa-fw icon-category"></i>
                        
                            <a href="/categories/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0/" class="post-category">
                                机器学习
                            </a>
                        
                    </div>
                    
                </div>
            </div>

            <div class="post-info">
                
                <div class="post-date info-break-policy">
                    <i class="far fa-calendar-minus fa-fw"></i>发布日期:&nbsp;&nbsp;
                    2020-06-30
                </div>
                

                

                
                <div class="info-break-policy">
                    <i class="far fa-file-word fa-fw"></i>文章字数:&nbsp;&nbsp;
                    7k
                </div>
                

                

                
                    <div id="busuanzi_container_page_pv" class="info-break-policy">
                        <i class="far fa-eye fa-fw"></i>阅读次数:&nbsp;&nbsp;
                        <span id="busuanzi_value_page_pv"></span>
                    </div>
				
            </div>
        </div>
        <hr class="clearfix">
        <div class="card-content article-card-content">
            <div id="articleContent">
                <h2 id="项目地址传送门，欢迎-star-和-fork-！"><a href="#项目地址传送门，欢迎-star-和-fork-！" class="headerlink" title="项目地址传送门，欢迎 star 和 fork ！"></a>项目地址<a href="https://github.com/DongZhouGu/scikit-learn-ml" target="_blank" rel="noopener">传送门</a>，欢迎 star 和 fork ！</h2><h1 id="线性回归算法"><a href="#线性回归算法" class="headerlink" title="线性回归算法"></a>线性回归算法</h1><p>线性回归算法是使用线性方程对数据集拟合的算法。本文先介绍单变量线性回归算法和多变量线性回归算法，其中损失函数以及梯度下降算法的推导过程会用到部分线性代数和偏导数；接着重点介绍了梯度下降算法的求解步骤以及性能优化方面的内容；最后通过一个房价预测模型，介绍了线性回归算法性能优化的一些常用步骤和方法。</p>
<h2 id="线性回归概述"><a href="#线性回归概述" class="headerlink" title="线性回归概述"></a>线性回归概述</h2><p>说到回归，一般都是指 <code>线性回归(linear regression)</code>。线性回归意味着可以将输入项分别乘以一些常量，再将结果加起来得到输出。回归的目的是预测数值型的目标值，最直接的办法是依据输入写出一个目标值的计算公式。</p>
<p>假如你想要预测兰博基尼跑车的功率大小，可能会这样计算:</p>
<blockquote>
<p>HorsePower = 0.0015 * annualSalary - 0.99 * hoursListeningToPublicRadio</p>
</blockquote>
<p>这就是所谓的 <code>回归方程(regression equation)</code>，其中的 0.0015 和 -0.99 称作 <code>回归系数（regression weights）</code>，求这些回归系数的过程就是回归。一旦有了这些回归系数，再给定输入，做预测就非常容易了。具体的做法是用回归系数乘以输入值，再将结果全部加在一起，就得到了预测值。我们这里所说的，回归系数是一个向量，输入也是向量，这些运算也就是求出二者的内积。</p>
<h2 id="单变量线性回归算法"><a href="#单变量线性回归算法" class="headerlink" title="单变量线性回归算法"></a>单变量线性回归算法</h2><p>先考虑最简单的单变量线性回归算法，即只有一个输入特征。</p>
<h3 id="预测函数"><a href="#预测函数" class="headerlink" title="预测函数"></a>预测函数</h3><p>针对数据集x和y，预测函数会根据输入特征x来计算输出值h(x)。其输入和输出的函数关系如下：<br>$$<br>h_{\theta}(x)=\theta_{0}+\theta_{1} x<br>$$</p>
<p>这个方程表达的是一条直线。我们的任务是构造一个 $h_{\theta}$ 函数，来映射数据集中的输入特征x和输出值y，使得预测函数 $h_{\theta}$ 计算出来的值与真实值y的整体误差最小。构造  $h_{\theta}$ 函数的关键就是找到合适的 $\theta_{0}$ 和 $\theta_{1}$ 的值，也就是所说的模型参数。</p>
<p>假设有如下的数据集：</p>
<table>
<thead>
<tr>
<th align="center">输入特征x</th>
<th align="center">输出y</th>
</tr>
</thead>
<tbody><tr>
<td align="center">1</td>
<td align="center">4</td>
</tr>
<tr>
<td align="center">2</td>
<td align="center">6</td>
</tr>
<tr>
<td align="center">3</td>
<td align="center">10</td>
</tr>
<tr>
<td align="center">4</td>
<td align="center">15</td>
</tr>
</tbody></table>
<p>假设模型参数 $\theta_{0}=1,  \theta_{1}=3$ ,  则预测函数为 $h_{\theta}(x)=1+3 x$ 。针对数据集中的第一个样本，输入为1，根据模型函数预测出来的值是4，与输出值y是吻合的。针对第二个样本，输入为2，根据模型函数预测出来的值是7，与实际输出值y相差1。模型的求解过程就是找出一组最合适的模型参数 $\theta_{0}$和 $\theta_{1}$，以便能最好地拟合数据集。</p>
<p>怎样来判断最好地拟合了数据集呢？没错，就是使用损失函数（也叫成本函数）。当拟合损失最小时，即找到了最好的拟合参数。</p>
<h3 id="损失函数"><a href="#损失函数" class="headerlink" title="损失函数"></a>损失函数</h3><p>单变量线性回归算法的损失函数是：</p>
<p>$$<br>J(\theta)=J\left(\theta_{0}, \theta_{1}\right)=\frac{1}{2 m} \sum_{i=1}^{m}\left(h\left(x^{(i)}\right)-y^{(i)}\right)^{2}<br>$$<br>其中，$h\left(x^{(i)}\right)-y^{(i)}$ 是预测值和真实值之间的误差，故损失就是预测值和真实值之间误差平方的平均值，之所以乘以1/2是为了方便计算。这个函数也称为均方差公式。有了损失函数，就可以精确地测量模型对训练样本拟合的好坏程度。</p>
<h3 id="梯度下降算法"><a href="#梯度下降算法" class="headerlink" title="梯度下降算法"></a>梯度下降算法</h3><p>有了预测函数，也可以精确地测量预测函数对训练样本的拟合情况。但怎么求解模型参数 $\theta_{0}$ 和 $\theta_{1}$ 的值呢？这时梯度下降算法就派上了用场。</p>
<p>我们的任务是找到合适的 $\theta_{0}$和 $\theta_{1}$ ，使得损失函数 $J\left(\theta_{0}, \theta_{1}\right)$ 最小。为了便于理解，我们切换到三维空间来描述这个任务。在一个三维空间里，以  $\theta_{0}$ 作为 x 轴， 以 $\theta_{1}$ 作为 y 轴，以损失函数 $J\left(\theta_{0}, \theta_{1}\right)$ 作为 z 轴，那么我们的任务就是要找出当 z 轴上的值最小的时候所对应的 x 轴上的值和 y 轴上的值。</p>
<p><strong>梯度下降算法的原理：</strong>先随机选择一组 $\theta_{0}$ 和 $\theta_{1}$ ，同时选择一个参数 $\alpha$ 作为移动的步长。然后，让x轴上的 $\theta_{0}$ 和 y轴上的  $\theta_{1}$ 分别向特定的方向移动一小步，这个步长的大小就由参数  $\alpha$ 决定。经过多次迭代之后，x 轴和 y 轴上的值决定的点就慢慢靠近 z 轴上的最小值处，如图所示。</p>
<p><img src="/medias/loading.gif" data-original="https://cdn.jsdelivr.net/gh/dongzhougu/imageuse1/u=3721595541,2272727131&amp;fm=26&amp;gp=0.jpg" alt="梯度下降示意图"></p>
<p>那特定的方向怎么确定呢？答案是<strong>偏导数</strong>。</p>
<p>可以简单地把偏导数理解为斜率。我们要让 $\theta_{j}$ 不停地迭代，由当前  $\theta_{j}$ 的值，根据 $J(\theta)$ 的偏导数函数，算出 $J(\theta)$ 在  $\theta_{j}$ 上的斜率，然后再乘以学习率  $\alpha$ ，就可以让 $\theta_{j}$ 往 $J(\theta)$ 变小的方向迈一小步。</p>
<p>用数学来描述上述过程，梯度下降的公式为：<br>$$<br>\theta_{j}=\theta_{j}-\alpha \frac{\partial}{\partial \theta_{j}} J(\theta)<br>$$<br>把损失函数 $J(\theta)$ 的定义代入上面的公式中，不难推导出梯度下降算法公式：<br>$$<br>\begin{array}{c}<br>\theta_{0}=\theta_{0}-\frac{\alpha}{m} \sum_{i=1}^{m}\left(h\left(x^{(i)}\right)-y^{(i)}\right) \\<br>\\<br>\theta_{1}=\theta_{1}-\frac{\alpha}{m} \sum_{i=1}^{m}\left(\left(h\left(x^{(i)}\right)-y^{(i)}\right) x^{(i)}\right)<br>\end{array}<br>$$<br>公式中， $\alpha$  是学习率；m 是训练样本的个数；$h\left(x^{(i)}\right)-y^{(i)}$ 是模型预测值和真实值的误差。需要注意的是，针对 $\theta_{0}$ 和 $\theta_{1}$ 分别求出了其迭代公式，在 $\theta_{1}$ 的迭代公式里，累加器中还需要乘以 $x^{(i)}$，具体参考扩展部分。</p>
<h2 id="多变量线性回归算法"><a href="#多变量线性回归算法" class="headerlink" title="多变量线性回归算法"></a>多变量线性回归算法</h2><p>实际应用中往往不止一个输入特征。熟悉了单变量线性回归算法后，我们来探讨一下多变量线性回归算法。</p>
<h3 id="预测函数-1"><a href="#预测函数-1" class="headerlink" title="预测函数"></a>预测函数</h3><p>上面介绍的单变量线性回归模型里只有一个输入特征，我们推广到更一般的情况，即多个输入特征。此时输出y的值由n个输入特征 $x_{1}, x_{2}, \ldots, x_{n}$ 决定。那么预测函数模型可以改写如下：</p>
<p>$$<br>h_{\theta}(x)=\theta_{0}+\theta_{1} x_{1}+\theta_{2} x_{2}+\ldots+\theta_{n} x_{n}<br>$$<br>假设 $x_{0}=1$，那么上面的公式可以重写为：<br>$$<br>h_{\theta}(x)=\sum_{j=0}^{n} \theta_{j} x_{j}<br>$$<br>其中，$\theta_{0}, \theta_{1}, \dots, \theta_{n}$ 统称为 $\theta$ , 是预测函数的参数。即一组 $\theta$ 值就决定了一个预测函数，记为 $h_{\theta}(x)$ , 为了简便起见，在不引起误解的情况下可以简写为 $h(x)$ 。理论上，预测函数有无穷多个，我们求解的目标就是找出一个最优的 $\theta$ 值。</p>
<h4 id="向量形式的预测函数"><a href="#向量形式的预测函数" class="headerlink" title="向量形式的预测函数"></a>向量形式的预测函数</h4><p>根据向量乘法运算法则，预测函数可重写为：</p>
<p>$$<br>h_{\theta}(x)=\left[\theta_{0}, \theta_{1}, \cdots, \theta_{n}\right]\left[\begin{array}{c}<br>x_{0} \\<br>x_{1} \\<br>\vdots \\<br>x_{n}<br>\end{array}\right]=\theta^{T} x<br>$$<br>此处，依然假设 $x_{0}=1$， $x_{0}$ 称为模型偏置（bias）。</p>
<p>写成向量形式的预测函数有两个原因。一是因为简洁，二是因为在实现算法时，要用到数值计算里的矩阵运算来提高效率，比如 <code>Numpy</code> 库里的矩阵运算。</p>
<h4 id="向量形式的训练样本"><a href="#向量形式的训练样本" class="headerlink" title="向量形式的训练样本"></a>向量形式的训练样本</h4><p>假设输入特征的个数是n，即 $x_{1}, x_{2}, \ldots, x_{n}$ , 我们总共有 m 个训练样本，为了书写方便，假设 $x_{0}=1$。这样训练样本可以写成矩阵的形式，即矩阵里每一行都是一个训练样本，总共有 m 行，每行有 n+1 列。</p>
<blockquote>
<p>思考：为什么不是n列而是n+1列？答案是：把模型偏置 $x_{0}$也加入了训练样本里。最后把训练样本写成一个矩阵，如下：</p>
</blockquote>
<p>$$<br>\boldsymbol{X}=\left[\begin{array}{ccccc}<br>x_{0}^{(1)} &amp; x_{1}^{(1)} &amp; x_{2}^{(1)} &amp; \cdots &amp; x_{n}^{(1)} \\<br>x_{0}^{(2)} &amp; x_{1}^{(2)} &amp; x_{2}^{(2)} &amp; \cdots &amp; x_{n}^{(2)} \\<br>\vdots &amp; \vdots &amp; \vdots &amp; \ddots &amp; \vdots \\<br>x_{0}^{(m)} &amp; x_{1}^{(m)} &amp; x_{2}^{(m)} &amp; \cdots &amp; x_{n}^{(m)}<br>\end{array}\right], \theta=\left[\begin{array}{c}<br>\theta_{0} \\<br>\theta_{1} \\<br>\theta_{2} \\<br>\vdots \\<br>\theta_{n}<br>\end{array}\right]<br>$$</p>
<p>理解训练样本矩阵的关键在于理解这些上标和下标的含义。其中，带括号的上标表示样本序号，从1到m；下标表示特征序号，从0到n，其中 $x_{0}$ 为常数1。</p>
<blockquote>
<p>$x_{j}^{(i)}$ 表示第 i 个训练样本的第 j 个特征的值。而 $x^{(i)}$ 只有上标，则表示第 i 个训练样本所构成的列向量。</p>
</blockquote>
<p>综上，训练样本的预测值 $h_{\theta}(X)$ ，可以使用下面的矩阵运算公式：</p>
<p>$$<br>h_{\theta}(X)=X \theta<br>$$</p>
<h3 id="损失函数-1"><a href="#损失函数-1" class="headerlink" title="损失函数"></a>损失函数</h3><p>多变量线性回归算法的损失函数：</p>
<p>$$<br>J(\theta)=\frac{1}{2 m} \sum_{i=1}^{m}\left(h\left(x^{(i)}\right)-y^{(i)}\right)^{2}<br>$$<br>其中，模型参数 $\theta$ 为 n+1 维的向量，$h\left(x^{(i)}\right)-y^{(i)}$ 是预测值和实际值的差，这个形式和单变量线性回归算法的类似。</p>
<p>损失函数有其对应的矩阵形式：<br>$$<br>J(\theta)=\frac{1}{2 m}(X \theta-\vec{y})^{T}(X \theta-\vec{y})<br>$$<br>其中，X 为 $m \times(n+1)$ 维的训练样本矩阵；上标T表示转置矩阵；$\vec{y}$ 表示由所有的训练样本的输出 $y^{(i)}$ 构成的向量。这个公式的优势是：没有累加器，不需要循环，直接使用矩阵运算，就可以一次性计算出对特定的参数 $\theta$ 下模型的拟合损失。</p>
<h3 id="梯度下降算法-1"><a href="#梯度下降算法-1" class="headerlink" title="梯度下降算法"></a>梯度下降算法</h3><p>根据单变量线性回归算法的介绍，梯度下降的公式为：<br>$$<br>\theta_{j}=\theta_{j}-\alpha \frac{\partial}{\partial \theta_{j}} J(\theta)<br>$$<br>公式中，下标 j 是参数的序号，其值从 0 到 n； $\alpha$ 为学习率。把损失函数代入上式，利用偏导数计算法则，不难推导出梯度下降算法的参数迭代公式：<br>$$<br>\theta_{j}=\theta_{j}-\frac{\alpha}{m} \sum_{i=1}^{m}\left(\left(h\left(x^{(i)}\right)-y^{(i)}\right) x_{j}^{(i)}\right)<br>$$<br>我们可以对比一下单变量线性回归函数的参数迭代公式。实际上和多变量线性回归函数的参数迭代公式是一模一样的。惟一的区别就是因为 $x_{0}$ 为常数1，在单变量线性回归算法的参数迭代公式中省去了。</p>
<p>应用这个公式编写机器学习算法，一般步骤如下：</p>
<ul>
<li><p>确定学习率： $\alpha$ 太大可能会使损失函数无法收敛，太小则计算太多，机器学习算法效率就比较低。</p>
</li>
<li><p>参数初始化：比如让所有的参数都以1作为起始点，$\theta_{0}=1, \theta_{1}=1, \dots, \theta_{n}=1$，根据预测值和损失函数，就可以算出在参数起始位置的损失。需要注意的是，参数起始点可以根据实际情况灵活选择，以便让机器学习算法的性能更高，比如选择比较靠近极点的位置。</p>
</li>
<li><p>计算参数的下一组值：根据梯度下降参数迭代公式，分别同时计算出新的 $\theta_{j}$ 值，进而得到新的预测函数 $h_{\theta}(x)$ 。再根据新的预测函数，代入损失函数就可以算出新的损失。</p>
</li>
<li><p>确定损失函数是否收敛：拿新的和旧的损失进行比较，看损失是不是变得越来越小。如果两次损失之间的差异小于误差范围，即说明已经非常靠近最小损失了，就可以近似地认为我们找到了最小损失。如果两次损失之间的差异在误差范围之外，重复步骤（3）继续计算下一组参数直到找到最优解。</p>
</li>
</ul>
<h2 id="模型优化"><a href="#模型优化" class="headerlink" title="模型优化"></a>模型优化</h2><p>线性回归模型常用的优化方法，包括增加多项式特征以及数据归一化处理等。</p>
<h3 id="多项式与线性回归"><a href="#多项式与线性回归" class="headerlink" title="多项式与线性回归"></a>多项式与线性回归</h3><p>当线性回归模型太简单导致欠拟合时，我们可以增加特征多项式来让线性回归模型更好地拟合数据。比如有两个特征  $x_{1}$ 和 $x_{2}$ ，可以增加两个特征的乘积 $x_{1} \times x_{2}$ 作为新特征  $x_{3}$ 。同理，我们也可以增加 $x_{1}^{2}$ 和 $x_{2}^{2}$  分别作为新特征  $x_{4}$ 和 $x_{5}$ 。</p>
<p>在 <code>scikit-learn</code> 里，线性回归是由类 <code>sklearn.linear_model.LinearRegression</code> 实现的，多项式由类 <code>sklearn.preprocessing.PolynomialFeatures</code> 实现。那么要怎样添加多项式特征呢？我们需要用一个管道把两个类串起来，即用 <code>sklearn.pipeline.Pipeline</code> 把这两个模型串起来。</p>
<p>比如下面的函数就可以创建一个多项式拟合：</p>
<pre class="line-numbers language-python"><code class="language-python"><span class="token keyword">from</span> sklearn<span class="token punctuation">.</span>linear_model <span class="token keyword">import</span> LinearRegression
<span class="token keyword">from</span> sklearn<span class="token punctuation">.</span>preprocessing <span class="token keyword">import</span> PolynomialFeatures
<span class="token keyword">from</span> sklearn<span class="token punctuation">.</span>pipeline <span class="token keyword">import</span> Pipeline
<span class="token keyword">def</span> <span class="token function">polynomial_model</span><span class="token punctuation">(</span>degree<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
    polynomial_features <span class="token operator">=</span> PolynomialFeatures<span class="token punctuation">(</span>degree<span class="token operator">=</span>degree<span class="token punctuation">,</span>include_bias<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span>
    linear_regression <span class="token operator">=</span> LinearRegression<span class="token punctuation">(</span>normalize<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span>
    <span class="token comment" spellcheck="true"># 这是一个流水线，先增加多项式阶数，然后再用线性回归算法来拟合数据</span>
    pipeline <span class="token operator">=</span> Pipeline<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token punctuation">(</span><span class="token string">"polynomial_features"</span><span class="token punctuation">,</span> polynomial_features<span class="token punctuation">)</span><span class="token punctuation">,</span>
                         <span class="token punctuation">(</span><span class="token string">"linear_regression"</span><span class="token punctuation">,</span> linear_regression<span class="token punctuation">)</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
    <span class="token keyword">return</span> pipeline<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>一个 Pipeline 可以包含多个处理节点，在 scikit-learn 里，除了最后一个节点外，其他的节点都必须实现 fit() 方法和 transform() 方法，最后一个节点只需要实现 fit() 方法即可。当训练样本数据送进 Pipeline 里进行处理时，它会逐个调用节点的 fit() 方法和 transform() 方法，最后调用最后一个节点的 fit() 方法来拟合数据。管道的示意图如下所示：</p>
<p><img src="/medias/loading.gif" data-original="https://cdn.jsdelivr.net/gh/dongzhougu/imageuse1/image-20200630093109778.png" alt="Pipeline 管道处理流程示意图"></p>
<h3 id="数据归一化"><a href="#数据归一化" class="headerlink" title="数据归一化"></a>数据归一化</h3><p>当线性回归模型有多个输入特征时，特别是使用多项式添加特征时，需要对数据进行归一化处理。比如，特征</p>
<p>$x_{1}$ 的范围在[1,4]之间，特征 $x_{2}$ 的范围在[1,2000]之间，这种情况下，可以让 $x_{1}$除以4来作为新特征 $x_{1}$，同时让 $x_{2}$ </p>
<p>除以2000来作为新特征 $x_{2}$ ，该过程称为特征缩放（feature scaling）。可以使用特征缩放来对训练样本进行归一化处理，处理后的特征范围在[0,1]之间。</p>
<ul>
<li>归一化处理的目的是让算法收敛更快，提升模型拟合过程中的计算效率。</li>
<li>进行归一化处理后，当有个新的样本需要计算预测值时，也需要先进行归一化处理，再通过模型来计算预测值，计算出来的预测值要再乘以归一化处理的系数，这样得到的数据才是真正的预测数据。</li>
<li>在 <code>scikit-learn</code> 里，使用 <code>LinearRegression</code> 进行线性回归时，可以指定 <code>normalize=True</code> 来对数据进行归一化处理。</li>
</ul>
<h2 id="示例1：使用线性回归算法拟合正弦函数"><a href="#示例1：使用线性回归算法拟合正弦函数" class="headerlink" title="示例1：使用线性回归算法拟合正弦函数"></a>示例1：使用线性回归算法拟合正弦函数</h2><p>首先生成200个在区间 $[2 \pi, 2 \pi]$ 内的正弦函数上的点，并给这些点加上一些随机的噪声。</p>
<pre class="line-numbers language-python"><code class="language-python"><span class="token keyword">import</span> numpy <span class="token keyword">as</span> np
n_dots <span class="token operator">=</span> <span class="token number">200</span>
X <span class="token operator">=</span> np<span class="token punctuation">.</span>linspace<span class="token punctuation">(</span><span class="token operator">-</span><span class="token number">2</span> <span class="token operator">*</span> np<span class="token punctuation">.</span>pi<span class="token punctuation">,</span> <span class="token number">2</span> <span class="token operator">*</span> np<span class="token punctuation">.</span>pi<span class="token punctuation">,</span> n_dots<span class="token punctuation">)</span>
Y <span class="token operator">=</span> np<span class="token punctuation">.</span>sin<span class="token punctuation">(</span>X<span class="token punctuation">)</span> <span class="token operator">+</span> <span class="token number">0.2</span> <span class="token operator">*</span> np<span class="token punctuation">.</span>random<span class="token punctuation">.</span>rand<span class="token punctuation">(</span>n_dots<span class="token punctuation">)</span> <span class="token operator">-</span> <span class="token number">0.1</span>
<span class="token comment" spellcheck="true"># 把一个n维向量转换成一个n*1维的矩阵</span>
X <span class="token operator">=</span> X<span class="token punctuation">.</span>reshape<span class="token punctuation">(</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span>
Y <span class="token operator">=</span> Y<span class="token punctuation">.</span>reshape<span class="token punctuation">(</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">;</span><span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>使用 <code>PolynomialFeatures</code> 和 <code>Pipeline</code> 创建一个多项式拟合模型</p>
<pre class="line-numbers language-python"><code class="language-python"><span class="token keyword">from</span> sklearn<span class="token punctuation">.</span>linear_model <span class="token keyword">import</span> LinearRegression
<span class="token keyword">from</span> sklearn<span class="token punctuation">.</span>preprocessing <span class="token keyword">import</span> PolynomialFeatures
<span class="token keyword">from</span> sklearn<span class="token punctuation">.</span>pipeline <span class="token keyword">import</span> Pipeline
<span class="token keyword">def</span> <span class="token function">polynomial_model</span><span class="token punctuation">(</span>degree<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
    polynomial_features <span class="token operator">=</span> PolynomialFeatures<span class="token punctuation">(</span>degree<span class="token operator">=</span>degree<span class="token punctuation">,</span>include_bias<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span>
    linear_regression <span class="token operator">=</span> LinearRegression<span class="token punctuation">(</span>normalize<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span>
    pipeline <span class="token operator">=</span> Pipeline<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token punctuation">(</span><span class="token string">"polynomial_features"</span><span class="token punctuation">,</span> polynomial_features<span class="token punctuation">)</span><span class="token punctuation">,</span>
                         <span class="token punctuation">(</span><span class="token string">"linear_regression"</span><span class="token punctuation">,</span> linear_regression<span class="token punctuation">)</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
    <span class="token keyword">return</span> pipeline<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>分别用2/3/5/10阶多项式来拟合数据集：</p>
<pre class="line-numbers language-python"><code class="language-python"><span class="token keyword">from</span> sklearn<span class="token punctuation">.</span>metrics <span class="token keyword">import</span> mean_squared_error
degrees <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">,</span> <span class="token number">5</span><span class="token punctuation">,</span> <span class="token number">10</span><span class="token punctuation">]</span>
results <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token punctuation">]</span>
<span class="token keyword">for</span> d <span class="token keyword">in</span> degrees<span class="token punctuation">:</span>
    model <span class="token operator">=</span> polynomial_model<span class="token punctuation">(</span>degree<span class="token operator">=</span>d<span class="token punctuation">)</span>
    model<span class="token punctuation">.</span>fit<span class="token punctuation">(</span>X<span class="token punctuation">,</span> Y<span class="token punctuation">)</span>
    train_score <span class="token operator">=</span> model<span class="token punctuation">.</span>score<span class="token punctuation">(</span>X<span class="token punctuation">,</span> Y<span class="token punctuation">)</span>
    mse <span class="token operator">=</span> mean_squared_error<span class="token punctuation">(</span>Y<span class="token punctuation">,</span> model<span class="token punctuation">.</span>predict<span class="token punctuation">(</span>X<span class="token punctuation">)</span><span class="token punctuation">)</span>
    results<span class="token punctuation">.</span>append<span class="token punctuation">(</span><span class="token punctuation">{</span><span class="token string">"model"</span><span class="token punctuation">:</span> model<span class="token punctuation">,</span> <span class="token string">"degree"</span><span class="token punctuation">:</span> d<span class="token punctuation">,</span> <span class="token string">"score"</span><span class="token punctuation">:</span> train_score<span class="token punctuation">,</span> <span class="token string">"mse"</span><span class="token punctuation">:</span> mse<span class="token punctuation">}</span><span class="token punctuation">)</span>
<span class="token keyword">for</span> r <span class="token keyword">in</span> results<span class="token punctuation">:</span>
    <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"degree: {}; train score: {}; mean squared error: {}"</span>
          <span class="token punctuation">.</span>format<span class="token punctuation">(</span>r<span class="token punctuation">[</span><span class="token string">"degree"</span><span class="token punctuation">]</span><span class="token punctuation">,</span> r<span class="token punctuation">[</span><span class="token string">"score"</span><span class="token punctuation">]</span><span class="token punctuation">,</span> r<span class="token punctuation">[</span><span class="token string">"mse"</span><span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>算出每个模型拟合的评分，此外，使用 <code>mean_squared_error</code> 算出均方误差，即实际的点和模型预测点之间距离平方的平均值，均方误差越小说明模型拟合效果越好——上述代码的输出结果为：</p>
<pre class="line-numbers language-python"><code class="language-python">degree<span class="token punctuation">:</span> <span class="token number">2</span><span class="token punctuation">;</span> train score<span class="token punctuation">:</span> <span class="token number">0.1543189069883787</span><span class="token punctuation">;</span> mean squared error<span class="token punctuation">:</span> <span class="token number">0.43058829267318416</span>
degree<span class="token punctuation">:</span> <span class="token number">3</span><span class="token punctuation">;</span> train score<span class="token punctuation">:</span> <span class="token number">0.2755383996826518</span><span class="token punctuation">;</span> mean squared error<span class="token punctuation">:</span> <span class="token number">0.3688679883773196</span>
degree<span class="token punctuation">:</span> <span class="token number">5</span><span class="token punctuation">;</span> train score<span class="token punctuation">:</span> <span class="token number">0.8982707756590037</span><span class="token punctuation">;</span> mean squared error<span class="token punctuation">:</span> <span class="token number">0.051796609130712795</span>
degree<span class="token punctuation">:</span> <span class="token number">10</span><span class="token punctuation">;</span> train score<span class="token punctuation">:</span> <span class="token number">0.9935830575581858</span><span class="token punctuation">;</span> mean squared error<span class="token punctuation">:</span> <span class="token number">0.0032672603337543927</span><span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span></span></code></pre>
<p>从输出结果可以看出，多项式阶数越高，拟合评分越高，均方误差越小，拟合效果越好。</p>
<p>把不同模型的拟合效果在二维坐标上画出来，可以清楚地看到不同阶数的多项式的拟合效果：</p>
<pre class="line-numbers language-python"><code class="language-python"><span class="token keyword">import</span> matplotlib<span class="token punctuation">.</span>pyplot <span class="token keyword">as</span> plt
<span class="token keyword">from</span> matplotlib<span class="token punctuation">.</span>figure <span class="token keyword">import</span> SubplotParams
plt<span class="token punctuation">.</span>figure<span class="token punctuation">(</span>figsize<span class="token operator">=</span><span class="token punctuation">(</span><span class="token number">12</span><span class="token punctuation">,</span><span class="token number">6</span><span class="token punctuation">)</span><span class="token punctuation">,</span>dpi<span class="token operator">=</span><span class="token number">200</span><span class="token punctuation">,</span>subplotpars<span class="token operator">=</span>SubplotParams<span class="token punctuation">(</span>hspace<span class="token operator">=</span><span class="token number">0.3</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
<span class="token keyword">for</span> i<span class="token punctuation">,</span>r <span class="token keyword">in</span> enumerate<span class="token punctuation">(</span>results<span class="token punctuation">)</span><span class="token punctuation">:</span>
    fig <span class="token operator">=</span> plt<span class="token punctuation">.</span>subplot<span class="token punctuation">(</span><span class="token number">2</span><span class="token punctuation">,</span><span class="token number">2</span><span class="token punctuation">,</span>i<span class="token operator">+</span><span class="token number">1</span><span class="token punctuation">)</span>
    plt<span class="token punctuation">.</span>xlim<span class="token punctuation">(</span><span class="token operator">-</span><span class="token number">8</span><span class="token punctuation">,</span><span class="token number">8</span><span class="token punctuation">)</span>
    plt<span class="token punctuation">.</span>title<span class="token punctuation">(</span><span class="token string">"LinearRegression degree={}"</span><span class="token punctuation">.</span>format<span class="token punctuation">(</span>r<span class="token punctuation">[</span><span class="token string">"degree"</span><span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
    plt<span class="token punctuation">.</span>scatter<span class="token punctuation">(</span>X<span class="token punctuation">,</span>Y<span class="token punctuation">,</span>s<span class="token operator">=</span><span class="token number">5</span><span class="token punctuation">,</span>c<span class="token operator">=</span><span class="token string">'b'</span><span class="token punctuation">,</span>alpha<span class="token operator">=</span><span class="token number">0.5</span><span class="token punctuation">)</span>
    plt<span class="token punctuation">.</span>plot<span class="token punctuation">(</span>X<span class="token punctuation">,</span>r<span class="token punctuation">[</span><span class="token string">"model"</span><span class="token punctuation">]</span><span class="token punctuation">.</span>predict<span class="token punctuation">(</span>X<span class="token punctuation">)</span><span class="token punctuation">,</span><span class="token string">'r-'</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>show<span class="token punctuation">(</span><span class="token punctuation">)</span><span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>使用 <code>SubplotParams</code> 调整了子图的竖直间距，并且使用 <code>subplot()</code> 函数把4个模型的拟合情况都画在同一个图形上。上述代码的输出结果如下图所示：</p>
<p><img src="/medias/loading.gif" data-original="https://cdn.jsdelivr.net/gh/dongzhougu/imageuse1/image-20200630090937011.png" alt="2、3、5、10 阶多项式模型对正弦数据的拟合效果对比图"></p>
<p>在[-2π，2π]区间内，10阶多项式对数据拟合得非常好，我们可以试着画出这10阶模型在[-20,20]的区域内的曲线，观察一下该模型的曲线和正弦函数的差异。代码如下：</p>
<pre class="line-numbers language-python"><code class="language-python">plt<span class="token punctuation">.</span>figure<span class="token punctuation">(</span>figsize<span class="token operator">=</span><span class="token punctuation">(</span><span class="token number">12</span><span class="token punctuation">,</span><span class="token number">6</span><span class="token punctuation">)</span><span class="token punctuation">,</span>dpi<span class="token operator">=</span><span class="token number">200</span><span class="token punctuation">)</span>
X <span class="token operator">=</span> np<span class="token punctuation">.</span>linspace<span class="token punctuation">(</span><span class="token operator">-</span><span class="token number">20</span><span class="token punctuation">,</span><span class="token number">20</span><span class="token punctuation">,</span><span class="token number">2000</span><span class="token punctuation">)</span><span class="token punctuation">.</span>reshape<span class="token punctuation">(</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span>
Y <span class="token operator">=</span> np<span class="token punctuation">.</span>sin<span class="token punctuation">(</span>X<span class="token punctuation">)</span><span class="token punctuation">.</span>reshape<span class="token punctuation">(</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span>
model_10 <span class="token operator">=</span> results<span class="token punctuation">[</span><span class="token number">3</span><span class="token punctuation">]</span><span class="token punctuation">[</span><span class="token string">"model"</span><span class="token punctuation">]</span>
plt<span class="token punctuation">.</span>xlim<span class="token punctuation">(</span><span class="token operator">-</span><span class="token number">20</span><span class="token punctuation">,</span><span class="token number">20</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>ylim<span class="token punctuation">(</span><span class="token operator">-</span><span class="token number">2</span><span class="token punctuation">,</span><span class="token number">2</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>plot<span class="token punctuation">(</span>X<span class="token punctuation">,</span>Y<span class="token punctuation">,</span><span class="token string">'b-'</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>plot<span class="token punctuation">(</span>X<span class="token punctuation">,</span>model_10<span class="token punctuation">.</span>predict<span class="token punctuation">(</span>X<span class="token punctuation">)</span><span class="token punctuation">,</span><span class="token string">'r-'</span><span class="token punctuation">)</span>
dot1 <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">2</span><span class="token operator">*</span>np<span class="token punctuation">.</span>pi<span class="token punctuation">,</span><span class="token number">0</span><span class="token punctuation">]</span>
dot2 <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token number">2</span><span class="token operator">*</span>np<span class="token punctuation">.</span>pi<span class="token punctuation">,</span><span class="token number">0</span><span class="token punctuation">]</span>
plt<span class="token punctuation">.</span>scatter<span class="token punctuation">(</span>dot1<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span>dot1<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">,</span>s<span class="token operator">=</span><span class="token number">50</span><span class="token punctuation">,</span>c<span class="token operator">=</span><span class="token string">'r'</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>scatter<span class="token punctuation">(</span>dot2<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span>dot2<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">,</span>s<span class="token operator">=</span><span class="token number">50</span><span class="token punctuation">,</span>c<span class="token operator">=</span><span class="token string">'r'</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>show<span class="token punctuation">(</span><span class="token punctuation">)</span><span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>上述代码的输出结果如下图：</p>
<p><img src="/medias/loading.gif" data-original="https://cdn.jsdelivr.net/gh/dongzhougu/imageuse1/image-20200630100257358.png" alt="10 阶多项式模型与正弦曲线在 [-20,20] 区间内的对比图"></p>
<p>从图中可以看出，10阶多项式模型只有在区间[-2π,2π]之间对正弦曲线拟合较好，在此区间以外，两者相差甚远。此案例告诉我们，每个模型都有自己的适用范围，在满足适用范围的基本前提下，要尽可能寻找拟合程度最高的模型来使用。</p>
<h2 id="示例2：预测房价"><a href="#示例2：预测房价" class="headerlink" title="示例2：预测房价"></a>示例2：预测房价</h2><p>本节使用 <code>scikit-learn</code> 自带的波士顿房价数据来训练模型，然后用模型来预测房价。</p>
<h3 id="输入特征"><a href="#输入特征" class="headerlink" title="输入特征"></a>输入特征</h3><p>房价和哪些因素有关？很多人可能对这个问题特别敏感，随时可以列出很多，如房子面子、房子地理位置、周边教育资源、周边商业资源、房子朝向、年限、小区情况等。在 <code>scikit-learn</code>的波士顿房价数据集里，它总共收集了13个特征，具体如下：</p>
<ul>
<li><p>CRIM：城镇人均犯罪率。</p>
</li>
<li><p>ZN：城镇超过25000平方英尺的住宅区域的占地比例。</p>
</li>
<li><p>INDUS：城镇非零售用地占地比例。</p>
</li>
<li><p>CHAS：是否靠近河边，1为靠近，0为远离。</p>
</li>
<li><p>NOX：一氧化氮浓度。</p>
</li>
<li><p>RM：每套房产的平均房间个数。</p>
</li>
<li><p>AGE：在1940年之前就盖好，且业主自住的房子的比例。</p>
</li>
<li><p>DIS：与波士顿市中心的距离。</p>
</li>
<li><p>RAD：周边高速公路的便利性指数。</p>
</li>
<li><p>TAX：每10000美元的财产税率。</p>
</li>
<li><p>PTRATIO：城镇的学生与教师人数之比。</p>
</li>
<li><p>B：城镇黑人的比例。</p>
</li>
<li><p>LSTAT：地位较低的人口比例。</p>
</li>
</ul>
<p>从这些指标里可以看到中美指标的一些差异。当然，这个数据是在1993年之前收集的，可能和现在会有差异。不要小看了这些指标，实际上一个模型的好坏和输入特征的选择关系密切。大家可以思考一下，如果要在中国预测房价，你会收集哪些特征数据？这些特征数据的可获得性如何？收集成本多高？</p>
<p>先导入数据：</p>
<pre class="line-numbers language-python"><code class="language-python"><span class="token keyword">from</span> sklearn<span class="token punctuation">.</span>datasets <span class="token keyword">import</span> load_boston
boston <span class="token operator">=</span> load_boston<span class="token punctuation">(</span><span class="token punctuation">)</span>
X <span class="token operator">=</span> boston<span class="token punctuation">.</span>data
y <span class="token operator">=</span> boston<span class="token punctuation">.</span>target
<span class="token keyword">print</span><span class="token punctuation">(</span>X<span class="token punctuation">.</span>shape<span class="token punctuation">)</span>  <span class="token comment" spellcheck="true"># (506, 13)</span><span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>表明这个数据集有506个样本，每个样本有13个特征。整个训练样本放在一个506*13的矩阵里。可以通过X[0]来查看一个样本数据：</p>
<pre class="line-numbers language-python"><code class="language-python"><span class="token keyword">print</span><span class="token punctuation">(</span>X<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
array<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">6.320e-03</span><span class="token punctuation">,</span> <span class="token number">1.800e+01</span><span class="token punctuation">,</span> <span class="token number">2.310e+00</span><span class="token punctuation">,</span> <span class="token number">0.000e+00</span><span class="token punctuation">,</span> <span class="token number">5.380e-01</span><span class="token punctuation">,</span> <span class="token number">6.575e+00</span><span class="token punctuation">,</span>
       <span class="token number">6.520e+01</span><span class="token punctuation">,</span> <span class="token number">4.090e+00</span><span class="token punctuation">,</span> <span class="token number">1.000e+00</span><span class="token punctuation">,</span> <span class="token number">2.960e+02</span><span class="token punctuation">,</span> <span class="token number">1.530e+01</span><span class="token punctuation">,</span> <span class="token number">3.969e+02</span><span class="token punctuation">,</span>
       <span class="token number">4.980e+00</span><span class="token punctuation">]</span><span class="token punctuation">)</span><span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span></span></code></pre>
<p>还可以通过 <code>boston.feature_names</code> 来查看这些特征的标签：</p>
<pre class="line-numbers language-python"><code class="language-python"><span class="token keyword">print</span><span class="token punctuation">(</span>boston<span class="token punctuation">.</span>feature_names<span class="token punctuation">)</span><span aria-hidden="true" class="line-numbers-rows"><span></span></span></code></pre>
<p>输出如下：</p>
<pre class="line-numbers language-python"><code class="language-python">array<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token string">'CRIM'</span><span class="token punctuation">,</span> <span class="token string">'ZN'</span><span class="token punctuation">,</span> <span class="token string">'INDUS'</span><span class="token punctuation">,</span> <span class="token string">'CHAS'</span><span class="token punctuation">,</span> <span class="token string">'NOX'</span><span class="token punctuation">,</span> <span class="token string">'RM'</span><span class="token punctuation">,</span> <span class="token string">'AGE'</span><span class="token punctuation">,</span> <span class="token string">'DIS'</span><span class="token punctuation">,</span> <span class="token string">'RAD'</span><span class="token punctuation">,</span>
       <span class="token string">'TAX'</span><span class="token punctuation">,</span> <span class="token string">'PTRATIO'</span><span class="token punctuation">,</span> <span class="token string">'B'</span><span class="token punctuation">,</span> <span class="token string">'LSTAT'</span><span class="token punctuation">]</span><span class="token punctuation">,</span> dtype<span class="token operator">=</span><span class="token string">'&lt;U7'</span><span class="token punctuation">)</span><span aria-hidden="true" class="line-numbers-rows"><span></span><span></span></span></code></pre>
<p>我们可以把特征和数值对应起来，观察一下数据。</p>
<h3 id="模型训练"><a href="#模型训练" class="headerlink" title="模型训练"></a>模型训练</h3><p>在 <code>scikit-learn</code> 里，<code>LinearRegression</code> 类实现了线性回归算法。在对模型进行训练之前，我们需要先把数据集分成两份，以便评估算法的准确性。</p>
<pre class="line-numbers language-python"><code class="language-python"><span class="token keyword">from</span> sklearn<span class="token punctuation">.</span>model_selection <span class="token keyword">import</span> train_test_split
X_train<span class="token punctuation">,</span>X_test<span class="token punctuation">,</span>y_train<span class="token punctuation">,</span>y_test<span class="token operator">=</span>train_test_split<span class="token punctuation">(</span>X<span class="token punctuation">,</span>y<span class="token punctuation">,</span>test_size<span class="token operator">=</span><span class="token number">0.2</span><span class="token punctuation">,</span>random_state<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">)</span><span aria-hidden="true" class="line-numbers-rows"><span></span><span></span></span></code></pre>
<p>由于数据量比较小，我们只选了20%的样本来作为测试数据集。接着，训练模型并测试模型的准确性评分：</p>
<pre class="line-numbers language-python"><code class="language-python"><span class="token keyword">import</span> time
<span class="token keyword">from</span> sklearn<span class="token punctuation">.</span>linear_model <span class="token keyword">import</span> LinearRegression
model <span class="token operator">=</span> LinearRegression<span class="token punctuation">(</span><span class="token punctuation">)</span>
start <span class="token operator">=</span> time<span class="token punctuation">.</span>process_time<span class="token punctuation">(</span><span class="token punctuation">)</span>
model<span class="token punctuation">.</span>fit<span class="token punctuation">(</span>X_train<span class="token punctuation">,</span>y_train<span class="token punctuation">)</span>
train_score <span class="token operator">=</span> model<span class="token punctuation">.</span>score<span class="token punctuation">(</span>X_train<span class="token punctuation">,</span>y_train<span class="token punctuation">)</span>
test_score <span class="token operator">=</span> model<span class="token punctuation">.</span>score<span class="token punctuation">(</span>X_test<span class="token punctuation">,</span>y_test<span class="token punctuation">)</span>
<span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"elaspe:{0:.6f};train_score:{1:0.6f};test_score:{2:.6f}"</span>
      <span class="token punctuation">.</span>format<span class="token punctuation">(</span>time<span class="token punctuation">.</span>process_time<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token operator">-</span>start<span class="token punctuation">,</span>train_score<span class="token punctuation">,</span>test_score<span class="token punctuation">)</span><span class="token punctuation">)</span><span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>我们顺便统计了模型的训练时间，除此之外，统计模型对训练样本的准确性得分（即对训练样本拟合的好坏程度）<code>train_score</code>，还测试了模型对测试样本的得分 <code>test_score</code>。运行结果如下：</p>
<pre class="line-numbers language-python"><code class="language-python">elaspe<span class="token punctuation">:</span><span class="token number">0.000000</span><span class="token punctuation">;</span>train_score<span class="token punctuation">:</span><span class="token number">0.723941</span><span class="token punctuation">;</span>test_score<span class="token punctuation">:</span><span class="token number">0.795262</span><span aria-hidden="true" class="line-numbers-rows"><span></span></span></code></pre>
<p>从得分情况来看，模型的拟合效果一般，还有没有办法来优化模型的拟合效果呢？</p>
<h3 id="模型优化-1"><a href="#模型优化-1" class="headerlink" title="模型优化"></a>模型优化</h3><p>首先观察一下数据，特征数据的范围相差比较大，最小的在$10^{-3}$级别，而最大的在$10^{2}$级别，看来我们需要先把数据进行归一化处理。归一化处理最简单的方式是，创建线性回归模型时增加normalize=True参数：</p>
<pre class="line-numbers language-python"><code class="language-python">model <span class="token operator">=</span> LinearRegression<span class="token punctuation">(</span>normalize<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span><span aria-hidden="true" class="line-numbers-rows"><span></span></span></code></pre>
<p>当然，数据归一化处理只会加快算法收敛速度，优化算法训练的效率，无法提升算法的准确性。</p>
<p>怎么样优化模型的准确性呢？我们回到训练分数上来，可以观察到模型针对训练样本的评分比较低（train_score:0.723941），即模型对训练样本的拟合成本比较高，这是一个典型的欠拟合现象。回忆我们之前介绍的优化欠拟合模型的方法，一是挖掘更多的输入特征，二是增加多项式特征。在我们这个例子里，通过使用低成本的方案——即增加多项式特征来看能否优化模型的性能。增加多项式特征，其实就是增加模型的复杂度。</p>
<p>我们使用之前创建多项式模型的函数 <code>polynomial_model</code>，接着，我们使用二阶多项式来拟合数据：</p>
<pre class="line-numbers language-python"><code class="language-python">model <span class="token operator">=</span> polynomial_model<span class="token punctuation">(</span>degree<span class="token operator">=</span><span class="token number">2</span><span class="token punctuation">)</span>
start <span class="token operator">=</span> time<span class="token punctuation">.</span>process_time<span class="token punctuation">(</span><span class="token punctuation">)</span>
model<span class="token punctuation">.</span>fit<span class="token punctuation">(</span>X_train<span class="token punctuation">,</span> y_train<span class="token punctuation">)</span>
train_score <span class="token operator">=</span> model<span class="token punctuation">.</span>score<span class="token punctuation">(</span>X_train<span class="token punctuation">,</span> y_train<span class="token punctuation">)</span>
test_score <span class="token operator">=</span> model<span class="token punctuation">.</span>score<span class="token punctuation">(</span>X_test<span class="token punctuation">,</span> y_test<span class="token punctuation">)</span>
<span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"elaspe:{0:.6f};train_score:{1:0.6f};test_score:{2:.6f}"</span>
      <span class="token punctuation">.</span>format<span class="token punctuation">(</span>time<span class="token punctuation">.</span>process_time<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token operator">-</span> start<span class="token punctuation">,</span> train_score<span class="token punctuation">,</span> test_score<span class="token punctuation">)</span><span class="token punctuation">)</span><span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>输出结果是：</p>
<pre class="line-numbers language-python"><code class="language-python">elaspe<span class="token punctuation">:</span><span class="token number">0.078125</span><span class="token punctuation">;</span>train_score<span class="token punctuation">:</span><span class="token number">0.930547</span><span class="token punctuation">;</span>test_score<span class="token punctuation">:</span><span class="token number">0.860049</span><span aria-hidden="true" class="line-numbers-rows"><span></span></span></code></pre>
<p>训练样本分数和测试分数都提高了，看来模型确实得到了优化。我们可以把多项式改为3阶看一下效果：</p>
<pre class="line-numbers language-python"><code class="language-python">elaspe<span class="token punctuation">:</span><span class="token number">0.093750</span><span class="token punctuation">;</span>train_score<span class="token punctuation">:</span><span class="token number">1.000000</span><span class="token punctuation">;</span>test_score<span class="token punctuation">:</span><span class="token operator">-</span><span class="token number">105.548323</span><span aria-hidden="true" class="line-numbers-rows"><span></span></span></code></pre>
<p>改为3阶多项式后，针对训练样本的分数达到了1，而针对测试样本的分数却是负数，说明这个模型过拟合了。</p>
<p>思考：我们总共有13个输入特征，从一阶多项式变为二阶多项式，输入特征个数增加了多少个？<br> 参考：二阶多项式共有：13个单一的特征，$C_{13}^{2}=78$ 个两两配对的特征，13个各自平方的特征，共计104个特征。比一阶多项式的13个特征增加了91个特征。</p>
<h3 id="学习曲线"><a href="#学习曲线" class="headerlink" title="学习曲线"></a>学习曲线</h3><p>更好的方法是画出学习曲线，这样对模型的状态以及优化的方向就一目了然。</p>
<pre class="line-numbers language-python"><code class="language-python"><span class="token keyword">import</span> matplotlib<span class="token punctuation">.</span>pyplot <span class="token keyword">as</span> plt
<span class="token keyword">from</span> utils <span class="token keyword">import</span> plot_learning_curve
<span class="token keyword">from</span> sklearn<span class="token punctuation">.</span>model_selection <span class="token keyword">import</span> ShuffleSplit
cv <span class="token operator">=</span> ShuffleSplit<span class="token punctuation">(</span>n_splits<span class="token operator">=</span><span class="token number">10</span><span class="token punctuation">,</span>test_size<span class="token operator">=</span><span class="token number">0.2</span><span class="token punctuation">,</span>random_state<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">)</span>
plt<span class="token punctuation">.</span>figure<span class="token punctuation">(</span>figsize<span class="token operator">=</span><span class="token punctuation">(</span><span class="token number">18</span><span class="token punctuation">,</span><span class="token number">4</span><span class="token punctuation">)</span><span class="token punctuation">,</span>dpi<span class="token operator">=</span><span class="token number">200</span><span class="token punctuation">)</span>
title <span class="token operator">=</span> <span class="token string">'Learning Curves (degree={0})'</span>
degrees <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">,</span><span class="token number">2</span><span class="token punctuation">,</span><span class="token number">3</span><span class="token punctuation">]</span>
start <span class="token operator">=</span> time<span class="token punctuation">.</span>process_time<span class="token punctuation">(</span><span class="token punctuation">)</span>
<span class="token keyword">for</span> i <span class="token keyword">in</span> range<span class="token punctuation">(</span>len<span class="token punctuation">(</span>degrees<span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
    plt<span class="token punctuation">.</span>subplot<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">,</span><span class="token number">3</span><span class="token punctuation">,</span>i<span class="token operator">+</span><span class="token number">1</span><span class="token punctuation">)</span>
    plot_learning_curve<span class="token punctuation">(</span>plt<span class="token punctuation">,</span>polynomial_model<span class="token punctuation">(</span>degrees<span class="token punctuation">[</span>i<span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">,</span>title<span class="token punctuation">.</span>format<span class="token punctuation">(</span>degrees<span class="token punctuation">[</span>i<span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
                        X<span class="token punctuation">,</span>y<span class="token punctuation">,</span>ylim<span class="token operator">=</span><span class="token punctuation">(</span><span class="token number">0.01</span><span class="token punctuation">,</span><span class="token number">1.01</span><span class="token punctuation">)</span><span class="token punctuation">,</span>cv<span class="token operator">=</span>cv<span class="token punctuation">)</span>
    <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">'elaspe:{0:.6f}'</span><span class="token punctuation">.</span>format<span class="token punctuation">(</span>time<span class="token punctuation">.</span>process_time<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token operator">-</span>start<span class="token punctuation">)</span><span class="token punctuation">)</span><span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>输出如下：</p>
<p><img src="/medias/loading.gif" data-original="https://cdn.jsdelivr.net/gh/dongzhougu/imageuse1/image-20200630092742917.png" alt="一阶、二阶、三阶多项式模型的学习曲线"></p>
<p>从学习曲线中可以看出，一阶多项式欠拟合，因为针对训练样本的分数比较低；而三阶多项式过拟合，因为针对训练样本的分数达到1，却看不到交叉验证数据集的分数。针对二阶多项式拟合的情况，虽然比一阶多项式的效果好，但从图中可以明显地看出来，针对训练数据集的分数和针对交叉验证数据集的分数之间的间隔比较大，这说明训练样本数量不够，我们应该去采集更多的数据，以提高模型的准确性。</p>
<h2 id="拓展阅读"><a href="#拓展阅读" class="headerlink" title="拓展阅读"></a>拓展阅读</h2><p>本节内容涉及到较多的数学知识，特别是矩阵和偏导数运算法则。如果阅读起来有困难，可以先跳过。如果有一定数学基础，这些知识对理解算法的实现细节及算法的效率有较大的帮助。</p>
<h3 id="公式推导的数学基础"><a href="#公式推导的数学基础" class="headerlink" title="公式推导的数学基础"></a>公式推导的数学基础</h3><p>AI的数学基础最主要的是高等数学、线性代数、概率论与数理统计这三门课程。下面是简易的入门文章供参考</p>
<ul>
<li>高等数学 <a href="https://zhuanlan.zhihu.com/p/36311622" target="_blank" rel="noopener">https://zhuanlan.zhihu.com/p/36311622</a></li>
<li>线性代数 <a href="https://zhuanlan.zhihu.com/p/36584206" target="_blank" rel="noopener">https://zhuanlan.zhihu.com/p/36584206</a></li>
<li>概率论与数理统计 <a href="https://zhuanlan.zhihu.com/p/36584335" target="_blank" rel="noopener">https://zhuanlan.zhihu.com/p/36584335</a></li>
</ul>
<h3 id="随机梯度下降算法"><a href="#随机梯度下降算法" class="headerlink" title="随机梯度下降算法"></a>随机梯度下降算法</h3><p>本章介绍的梯度下降算法迭代公式称为批量梯度下降算法（Batch Gradient Descent，简称BGD），用它对参数进行一次迭代运算，需要遍历所有的训练数据集。当训练数据集比较大时，其算法的效率会比较低。考虑另外一个算法：<br>$$<br>\theta_{j}=\theta_{j}-\alpha\left(\left(h\left(x^{(i)}\right)-y^{(i)}\right) x_{j}^{(i)}\right)<br>$$<br> 这个算法的关键点是把累加器去掉，不去遍历所有的数据集，而是改成每次随机地从训练数据集中取一个数据进行参数迭代计算，这就是随机梯度下降算法（Stochastic Gradient Descent，简称SGD）。随机梯度下降算法可以大大提高模型训练的效率。</p>
<h3 id="正规方程"><a href="#正规方程" class="headerlink" title="正规方程"></a>正规方程</h3><p>梯度下降算法通过不断地迭代，从而不停地逼近成本函数的最小值来求解模型的参数。另外一个方法是直接计算成本函数的导数，令导数为0，求解这个方程，即可得到线性回归的解。<br> 线性回归算法的成本函数：<br>$$<br>J(\theta)=\frac{1}{2 m} \sum_{i=1}^{m}\left(h_{\theta}\left(x^{(i)}\right)-y^{(i)}\right)^{2}<br>$$<br>成本函数的“斜率”为0的点，即为模型参数的解。令$\frac{\partial}{\partial \theta} J(\theta)=0$，求解这个方程最终可以得到模型参数：<br>$$<br>\theta=\left(X^{T} X\right)^{-1} X^{T} y<br>$$<br>这就是我们的正规方程。它通过矩阵运算，直接从训练样本里求出参数θ的值。其中X为训练样本的矩阵形式，它是m×n的矩阵，y是训练样本的结果数据，它是个m维列向量。方程求解过程可参阅<a href="https://baike.baidu.com/item/%E6%AD%A3%E8%A7%84%E6%96%B9%E7%A8%8B/10001812?fr=aladdin" target="_blank" rel="noopener">百度百科</a>。</p>
<script>
        // For each .github-emoji element that carries a data-src URL, load that URL in a
        // hidden throwaway <img>. If the load fails, reset the element's colour and
        // background styles; in either outcome the throwaway <img> is removed again.
        for (const emoji of document.querySelectorAll('.github-emoji')) {
          const probeUrl = emoji.dataset.src;
          if (probeUrl) {
            const probe = document.createElement('img');
            probe.style = 'display:none !important;';
            probe.src = probeUrl;
            probe.addEventListener('error', () => {
              probe.remove();
              // Applied in the same order as before: color, backgroundImage, background.
              Object.assign(emoji.style, {
                color: 'inherit',
                backgroundImage: 'none',
                background: 'none'
              });
            });
            probe.addEventListener('load', () => probe.remove());
            document.body.appendChild(probe);
          }
        }
      </script>
            </div>
            <hr/>

            

    <div class="reprint" id="reprint-statement">
        
            <div class="reprint__author">
                <span class="reprint-meta" style="font-weight: bold;">
                    <i class="fas fa-user">
                        文章作者:
                    </i>
                </span>
                <span class="reprint-info">
                    <a href="/about" rel="external nofollow noreferrer">DongZhou</a>
                </span>
            </div>
            <div class="reprint__type">
                <span class="reprint-meta" style="font-weight: bold;">
                    <i class="fas fa-link">
                        文章链接:
                    </i>
                </span>
                <span class="reprint-info">
                    <a href="https://dongzhougu.github.io/2020/06/30/scikit-learn-xi-lie-san-xian-xing-hui-gui/">https://dongzhougu.github.io/2020/06/30/scikit-learn-xi-lie-san-xian-xing-hui-gui/</a>
                </span>
            </div>
            <div class="reprint__notice">
                <span class="reprint-meta" style="font-weight: bold;">
                    <i class="fas fa-copyright">
                        版权声明:
                    </i>
                </span>
                <span class="reprint-info">
                    本博客所有文章除特別声明外，均采用
                    <a href="https://creativecommons.org/licenses/by/4.0/deed.zh" rel="external nofollow noreferrer" target="_blank">CC BY 4.0</a>
                    许可协议。转载请注明来源
                    <a href="/about" target="_blank">DongZhou</a>
                    !
                </span>
            </div>
        
    </div>

    <script async defer>
      // Whenever the visitor copies something, show a Materialize toast that reminds
      // them of the reprint rules and offers a button jumping to the reprint statement.
      document.addEventListener("copy", function (e) {
        // The action element is a <button>, so it must be closed with </button>
        // (it used to be closed with </a> and only rendered thanks to parser error recovery).
        const toastHTML = '<span>复制成功，请遵循本文的转载规则</span><button class="btn-flat toast-action" onclick="navToReprintStatement()" style="font-size: smaller">查看</button>';
        M.toast({html: toastHTML})
      });

      // Animate the page scroll to 80px above the #reprint-statement block.
      // Kept as a global function declaration because the toast button's inline
      // onclick attribute calls it by name.
      function navToReprintStatement() {
        $("html, body").animate({scrollTop: $("#reprint-statement").offset().top - 80}, 800);
      }
    </script>



            <div class="tag_share" style="display: block;">
                <div class="post-meta__tag-list" style="display: inline-block;">
                    
                        <div class="article-tag">
                            
                                <a href="/tags/%E5%9F%BA%E7%A1%80%E7%9F%A5%E8%AF%86/">
                                    <span class="chip bg-color">基础知识</span>
                                </a>
                            
                                <a href="/tags/ML%E7%AE%97%E6%B3%95/">
                                    <span class="chip bg-color">ML算法</span>
                                </a>
                            
                        </div>
                    
                </div>
                <div class="post_share" style="zoom: 80%; width: fit-content; display: inline-block; float: right; margin: -0.15rem 0;">
                    <link rel="stylesheet" type="text/css" href="/libs/share/css/share.min.css">
<div id="article-share">

    
    <div class="social-share" data-sites="google,qq,qzone,wechat,weibo,douban,linkedin" data-wechat-qrcode-helper="<p>微信扫一扫即可分享！</p>"></div>
    <script src="/libs/share/js/social-share.min.js"></script>
    

    

</div>

                </div>
            </div>
            
                <style>
    /* Styles for the "reward" (tip) widget: the #reward trigger area and the
       #rewardModal dialog with its two payment tabs. */

    /* Centered container holding the round trigger button. */
    #reward {
        margin: 40px 0;
        text-align: center;
    }

    #reward .reward-link {
        font-size: 1.4rem;
        line-height: 38px;
    }

    /* Stronger shadow while the trigger button is hovered. */
    #reward .btn-floating:hover {
        box-shadow: 0 6px 12px rgba(0, 0, 0, 0.2), 0 5px 15px rgba(0, 0, 0, 0.2);
    }

    /* Fixed-size dialog box. */
    #rewardModal {
        width: 320px;
        height: 350px;
    }

    #rewardModal .reward-title {
        margin: 15px auto;
        padding-bottom: 5px;
    }

    #rewardModal .modal-content {
        padding: 10px;
    }

    /* Close icon pinned to the dialog's top-right corner. */
    #rewardModal .close {
        position: absolute;
        right: 15px;
        top: 15px;
        color: rgba(0, 0, 0, 0.5);
        font-size: 1.3rem;
        line-height: 20px;
        cursor: pointer;
    }

    /* On hover the close icon changes colour and grows by 30%; the prefixed
       declarations repeat the same scale transform under vendor-specific names. */
    #rewardModal .close:hover {
        color: #ef5350;
        transform: scale(1.3);
        -moz-transform:scale(1.3);
        -webkit-transform:scale(1.3);
        -o-transform:scale(1.3);
    }

    /* Tab strip and QR images share the same 210px width. */
    #rewardModal .reward-tabs {
        margin: 0 auto;
        width: 210px;
    }

    .reward-tabs .tabs {
        height: 38px;
        margin: 10px auto;
        padding-left: 0;
    }

    /* !important so this wins over other padding-left rules that match this list,
       such as the page-wide "ul" selector in the comment styles further down. */
    .reward-content ul {
        padding-left: 0 !important;
    }

    .reward-tabs .tabs .tab {
        height: 38px;
        line-height: 38px;
    }

    /* Tabs are grey with white text by default; hovering does not change that. */
    .reward-tabs .tab a {
        color: #fff;
        background-color: #ccc;
    }

    .reward-tabs .tab a:hover {
        background-color: #ccc;
        color: #fff;
    }

    /* The active tab gets its own background colour per payment method. */
    .reward-tabs .wechat-tab .active {
        color: #fff !important;
        background-color: #22AB38 !important;
    }

    .reward-tabs .alipay-tab .active {
        color: #fff !important;
        background-color: #019FE8 !important;
    }

    /* Square QR-code images. */
    .reward-tabs .reward-img {
        width: 210px;
        height: 210px;
    }
</style>

<div id="reward">
    <a href="#rewardModal" class="reward-link modal-trigger btn-floating btn-medium waves-effect waves-light red">赏</a>

    <!-- Modal Structure -->
    <div id="rewardModal" class="modal">
        <div class="modal-content">
            <a class="close modal-close"><i class="fas fa-times"></i></a>
            <h4 class="reward-title">你的赏识是我前进的动力</h4>
            <div class="reward-content">
                <div class="reward-tabs">
                    <ul class="tabs row">
                        <li class="tab col s6 alipay-tab waves-effect waves-light"><a href="#alipay">支付宝</a></li>
                        <li class="tab col s6 wechat-tab waves-effect waves-light"><a href="#wechat">微 信</a></li>
                    </ul>
                    <div id="alipay">
                        <img src="https://cdn.jsdelivr.net/gh/DongZhouGu/DongZhouGu.github.io/medias/reward/alipay.jpg" class="reward-img" alt="支付宝打赏二维码">
                    </div>
                    <div id="wechat">
                        <img src="https://cdn.jsdelivr.net/gh/DongZhouGu/DongZhouGu.github.io/medias/reward/wechat.png" class="reward-img" alt="微信打赏二维码">
                    </div>
                </div>
            </div>
        </div>
    </div>
</div>

<script>
    // Once the DOM is ready, initialise the tabs plugin on every .tabs element
    // (the reward dialog's tab list uses this class).
    $(() => {
        $('.tabs').tabs();
    });
</script>

            
        </div>
    </div>

    

    

    

    

    
        <style>
    /*
     * Comment section styles: the Valine card wrapper plus typography for content
     * rendered inside #vcomments.
     *
     * Exact duplicate rule blocks and one dead rule (a "border: 0" that a later
     * duplicate overrode again) have been merged; the resulting computed styles
     * are the same as before.
     *
     * NOTE(review): several selector lists below scope only their FIRST selector
     * to #vcomments (e.g. "#vcomments ol, ul"); the remaining selectors match
     * page-wide. They are deliberately left unchanged because other parts of the
     * page may depend on them -- confirm before scoping them.
     */
    .valine-card {
        margin: 1.5rem auto;
    }

    .valine-card .card-content {
        padding: 20px 20px 5px 20px;
    }

    #vcomments textarea {
        box-sizing: border-box;
        background: url("/medias/comment_bg.png") 100% 100% no-repeat;
    }

    /* Merged from two "#vcomments p" rules; the later font-size/line-height won. */
    #vcomments p {
        margin: 2px 2px 10px;
        font-size: 1rem;
        line-height: 1.5rem;
    }

    #vcomments blockquote p {
        text-indent: 0.2rem;
    }

    #vcomments a {
        padding: 0 2px;
        color: #4cbf30;
        font-weight: 500;
        text-decoration: none;
    }

    #vcomments img {
        max-width: 100%;
        height: auto;
        cursor: pointer;
    }

    #vcomments ol li {
        list-style-type: decimal;
    }

    /* NOTE(review): "ul" is unscoped and matches every <ul> on the page. */
    #vcomments ol,
    ul {
        display: block;
        padding-left: 2em;
        word-spacing: 0.05rem;
    }

    /* NOTE(review): "ol li" is unscoped and matches every ordered-list item on the page. */
    #vcomments ul li,
    ol li {
        display: list-item;
        line-height: 1.8rem;
        font-size: 1rem;
    }

    #vcomments ul li {
        list-style-type: disc;
    }

    #vcomments ul ul li {
        list-style-type: circle;
    }

    /* NOTE(review): "th" and "td" are unscoped and match every table cell on the page. */
    #vcomments table, th, td {
        padding: 12px 13px;
        border: 1px solid #dfe2e5;
    }

    /* NOTE(review): unscoped -- zebra striping applies to every table on the page. */
    table tr:nth-child(2n), thead {
        background-color: #fafafa;
    }

    #vcomments table th {
        background-color: #f2f2f2;
        min-width: 80px;
    }

    #vcomments table td {
        min-width: 80px;
    }

    #vcomments h1 {
        font-size: 1.85rem;
        font-weight: bold;
        line-height: 2.2rem;
    }

    #vcomments h2 {
        font-size: 1.65rem;
        font-weight: bold;
        line-height: 1.9rem;
    }

    #vcomments h3 {
        font-size: 1.45rem;
        font-weight: bold;
        line-height: 1.7rem;
    }

    #vcomments h4 {
        font-size: 1.25rem;
        font-weight: bold;
        line-height: 1.5rem;
    }

    #vcomments h5 {
        font-size: 1.1rem;
        font-weight: bold;
        line-height: 1.4rem;
    }

    #vcomments h6 {
        font-size: 1rem;
        line-height: 1.3rem;
    }

    #vcomments hr {
        margin: 12px 0;
        border: 0;
        border-top: 1px solid #ccc;
    }

    #vcomments blockquote {
        margin: 15px 0;
        border-left: 5px solid #42b983;
        padding: 1rem 0.8rem 0.3rem 0.8rem;
        color: #666;
        background-color: rgba(66, 185, 131, .1);
    }

    #vcomments pre {
        font-family: monospace, monospace;
        padding: 1.2em;
        margin: .5em 0;
        background: #272822;
        overflow: auto;
        border-radius: 0.3em;
        tab-size: 4;
    }

    #vcomments code {
        font-family: monospace, monospace;
        padding: 1px 3px;
        font-size: 0.92rem;
        color: #e96900;
        background-color: #f8f8f8;
        border-radius: 2px;
    }

    #vcomments pre code {
        font-family: monospace, monospace;
        padding: 0;
        color: #e8eaf6;
        background-color: #272822;
    }

    #vcomments pre[class*="language-"] {
        padding: 1.2em;
        margin: .5em 0;
    }

    /* NOTE(review): 'pre[class*="language-"]' is unscoped and also colours the
       article's own highlighted code blocks. */
    #vcomments code[class*="language-"],
    pre[class*="language-"] {
        color: #e8eaf6;
    }

    /* NOTE(review): the ":checked" selector is unscoped and matches every checked
       checkbox on the page. */
    #vcomments [type="checkbox"]:not(:checked), [type="checkbox"]:checked {
        position: inherit;
        margin-left: -1.3rem;
        margin-right: 0.4rem;
        margin-top: -1px;
        vertical-align: middle;
        left: unset;
        visibility: visible;
    }

    /* NOTE(review): "strong" is unscoped and matches every <strong> on the page. */
    #vcomments b,
    strong {
        font-weight: bold;
    }

    #vcomments dfn {
        font-style: italic;
    }

    #vcomments small {
        font-size: 85%;
    }

    #vcomments cite {
        font-style: normal;
    }

    #vcomments mark {
        background-color: #fcf8e3;
        padding: .2em;
    }
</style>

<div class="card valine-card" data-aos="fade-up">
    <div class="comment_headling" style="font-size: 20px; font-weight: 700; position: relative; padding-left: 20px; top: 15px; padding-bottom: 5px;">
        <i class="fas fa-comments fa-fw" aria-hidden="true"></i>
        <span>评论</span>
    </div>
    <div id="vcomments" class="card-content" style="display: grid">
    </div>
</div>

<script src="/libs/valine/av-min.js"></script>
<script src="https://cdn.jsdelivr.net/gh/DongZhouGu/DongZhouGu.github.io/libs/valine/Valine.min.js"></script>
<script>
    // Create the Valine comment widget, mounted on the #vcomments container.
    // NOTE(review): the option semantics are defined by the bundled Valine build,
    // not by this file -- confirm against that version's documentation.
    new Valine({
        el: '#vcomments',
        // NOTE(review): these credentials are delivered to every visitor; presumably
        // they are client-side keys -- confirm they grant no privileged access.
        appId: 'RPCuj0HNm1eqAREO6c5T7nSJ-gzGzoHsz',
        appKey: 'laCdQbWLFWOWdkXVM3RxoXGe',
        // Each "'x' === 'true'" comparison evaluates to a plain boolean
        // (false, false and true respectively); presumably an artifact of
        // templating the site configuration into this page.
        notify: 'false' === 'true',
        verify: 'false' === 'true',
        visitor: 'true' === 'true',
        avatar: 'mm',
        // Passed as a string, not a number.
        pageSize: '10',
        lang: 'zh-cn',
        placeholder: '快来留言吧'
    });
</script>

    

    

    

<article id="prenext-posts" class="prev-next articles">
    <div class="row article-row">
        
        <div class="article col s12 m6" data-aos="fade-up">
            <div class="article-badge left-badge text-color">
                <i class="fas fa-chevron-left"></i>&nbsp;上一篇</div>
            <div class="card">
                <a href="/2020/06/30/scikit-learn-xi-lie-si-luo-ji-hui-gui/">
                    <div class="card-image">
                        
                        
                        <img src="https://cdn.jsdelivr.net/gh/DongZhouGu/DongZhouGu.github.io/medias/featureimages/9.jpg" class="responsive-img" alt="scikit-learn系列四：逻辑回归">
                        
                        <span class="card-title">scikit-learn系列四：逻辑回归</span>
                    </div>
                </a>
                <div class="card-content article-content">
                    <div class="summary block-with-text">
                        
                            实现逻辑回归算法，原理解释+癌症检测案例
                        
                    </div>
                    <div class="publish-info">
                        <span class="publish-date">
                            <i class="far fa-clock fa-fw icon-date"></i>2020-06-30
                        </span>
                        <span class="publish-author">
                            
                            <i class="fas fa-bookmark fa-fw icon-category"></i>
                            
                            <a href="/categories/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0/" class="post-category">
                                    机器学习
                                </a>
                            
                            
                        </span>
                    </div>
                </div>
                
                <div class="card-action article-tags">
                    
                    <a href="/tags/%E5%9F%BA%E7%A1%80%E7%9F%A5%E8%AF%86/">
                        <span class="chip bg-color">基础知识</span>
                    </a>
                    
                    <a href="/tags/ML%E7%AE%97%E6%B3%95/">
                        <span class="chip bg-color">ML算法</span>
                    </a>
                    
                </div>
                
            </div>
        </div>
        
        
        <div class="article col s12 m6" data-aos="fade-up">
            <div class="article-badge right-badge text-color">
                下一篇&nbsp;<i class="fas fa-chevron-right"></i>
            </div>
            <div class="card">
                <a href="/2020/06/29/scikit-learn-xi-lie-er-k-jin-lin/">
                    <div class="card-image">
                        
                        
                        <img src="https://cdn.jsdelivr.net/gh/DongZhouGu/DongZhouGu.github.io/medias/featureimages/6.jpg" class="responsive-img" alt="scikit-learn系列二：K-近邻">
                        
                        <span class="card-title">scikit-learn系列二：K-近邻</span>
                    </div>
                </a>
                <div class="card-content article-content">
                    <div class="summary block-with-text">
                        
                            使用scikit-learn实现K-近邻算法，多案例
                        
                    </div>
                    <div class="publish-info">
                            <span class="publish-date">
                                <i class="far fa-clock fa-fw icon-date"></i>2020-06-29
                            </span>
                        <span class="publish-author">
                            
                            <i class="fas fa-bookmark fa-fw icon-category"></i>
                            
                            <a href="/categories/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0/" class="post-category">
                                    机器学习
                                </a>
                            
                            
                        </span>
                    </div>
                </div>
                
                <div class="card-action article-tags">
                    
                    <a href="/tags/%E5%9F%BA%E7%A1%80%E7%9F%A5%E8%AF%86/">
                        <span class="chip bg-color">基础知识</span>
                    </a>
                    
                    <a href="/tags/ML%E7%AE%97%E6%B3%95/">
                        <span class="chip bg-color">ML算法</span>
                    </a>
                    
                </div>
                
            </div>
        </div>
        
    </div>
</article>

</div>



<!-- 代码块功能依赖 -->
<script type="text/javascript" src="/libs/codeBlock/codeBlockFuction.js"></script>

<!-- 代码语言 -->

<script type="text/javascript" src="/libs/codeBlock/codeLang.js"></script>


<!-- 代码块复制 -->

<script type="text/javascript" src="/libs/codeBlock/codeCopy.js"></script>


<!-- 代码块收缩 -->

<script type="text/javascript" src="/libs/codeBlock/codeShrink.js"></script>


<!-- 代码块折行 -->

<style type="text/css">
code[class*="language-"], pre[class*="language-"] { white-space: pre !important; }
</style>


    </div>
    <div id="toc-aside" class="expanded col l3 hide-on-med-and-down">
        <div class="toc-widget">
            <div class="toc-title"><i class="far fa-list-alt"></i>&nbsp;&nbsp;目录</div>
            <div id="toc-content"></div>
        </div>
    </div>
</div>

<!-- TOC 悬浮按钮. -->

<div id="floating-toc-btn" class="hide-on-med-and-down">
    <a class="btn-floating btn-large bg-color">
        <i class="fas fa-list-ul"></i>
    </a>
</div>


<script src="https://cdn.jsdelivr.net/gh/DongZhouGu/DongZhouGu.github.io/libs/tocbot/tocbot.min.js"></script>
<script>
    // Table-of-contents wiring, run once the DOM is ready: build the TOC with tocbot,
    // renumber its anchors, pin the widget while scrolling and hook up the floating
    // button that shows or hides the TOC column.
    $(() => {
        tocbot.init({
            tocSelector: '#toc-content',
            contentSelector: '#articleContent',
            headingsOffset: -($(window).height() * 0.4 - 45),
            collapseDepth: Number('0'),
            headingSelector: 'h2, h3, h4'
        });

        // Give TOC links and article headings matching sequential "toc-heading-N"
        // anchors (N starts at 1) so that linking also works for Chinese heading text.
        // Links are numbered in TOC order, headings in document order among the
        // direct children of #articleContent.
        const headingPrefix = 'toc-heading-';
        const nthHeadingId = (index) => headingPrefix + (index + 1);
        $('#toc-content a').attr('href', (index) => '#' + nthHeadingId(index));
        $('#articleContent').children('h2, h3, h4').attr('id', (index) => nthHeadingId(index));

        // Pin the TOC widget once the page has scrolled past the threshold,
        // and unpin it again above the threshold.
        const pinThreshold = parseInt($(window).height() * 0.4 - 64);
        const $widget = $('.toc-widget');
        $(window).on('scroll', () => {
            $widget.toggleClass('toc-fixed', $(window).scrollTop() > pinThreshold);
        });

        // Copy the source element's width, plus a width-dependent allowance, onto the
        // target element. Does nothing when the source element is not on the page.
        const fixPostCardWidth = (srcId, targetId) => {
            const $src = $('#' + srcId);
            if ($src.length === 0) {
                return;
            }

            const width = $src.width();
            let allowance = 14;
            if (width >= 450) {
                allowance = 21;
            } else if (width >= 350) {
                allowance = 18;
            } else if (width >= 300) {
                allowance = 16;
            }
            $('#' + targetId).width(width + allowance);
        };

        // The floating button toggles the TOC column; the main column carries the
        // "l9" grid class only while the TOC column is shown.
        const expandedClass = 'expanded';
        const $aside = $('#toc-aside');
        const $main = $('#main-content');
        $('#floating-toc-btn .btn-floating').on('click', () => {
            const expandNow = !$aside.hasClass(expandedClass);
            $aside.toggleClass(expandedClass, expandNow).toggle(expandNow);
            $main.toggleClass('l9', expandNow);
            fixPostCardWidth('artDetail', 'prenext-posts');
        });

    });
</script>

    

</main>


<script src="https://cdn.bootcss.com/mathjax/2.7.5/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script>
    // Accept $...$ and \(...\) as inline-math delimiters.
    // The backslashes have to be doubled: inside a JavaScript string literal
    // '\(' is just '(', which would register plain parentheses as math
    // delimiters and typeset every parenthesised phrase on the page as math.
    MathJax.Hub.Config({
        tex2jax: {inlineMath: [['$', '$'], ['\\(', '\\)']]}
    });
</script>



<footer class="page-footer bg-color">
    
    <div class="container row center-align" style="margin-bottom: 15px !important;">
        <div class="col s12 m8 l8 copy-right">
            Copyright&nbsp;&copy;
            <span id="year">2020</span>
            <a href="/about" target="_blank">DongZhou</a>
            |&nbsp;Powered by&nbsp;<a href="https://hexo.io/" target="_blank">Hexo</a>
            |&nbsp;Theme&nbsp;<a href="https://github.com/blinkfox/hexo-theme-matery" target="_blank">Matery</a>
            <br>
            
            &nbsp;<i class="fas fa-chart-area"></i>&nbsp;站点总字数:&nbsp;<span
                class="white-color">60.5k</span>&nbsp;字
            
            
            
            
            
            
            <span id="busuanzi_container_site_pv">
                |&nbsp;<i class="far fa-eye"></i>&nbsp;总访问量:&nbsp;<span id="busuanzi_value_site_pv"
                    class="white-color"></span>&nbsp;次
            </span>
            
            
            <span id="busuanzi_container_site_uv">
                |&nbsp;<i class="fas fa-users"></i>&nbsp;总访问人数:&nbsp;<span id="busuanzi_value_site_uv"
                    class="white-color"></span>&nbsp;人
            </span>
            
            <br>
            
            <span id="sitetime">载入运行时间...</span>
            <script>
                /**
                 * Refreshes the footer's copyright year (#year) and the uptime
                 * counter (#sitetime).
                 *
                 * The elapsed time is the difference between the hard-coded launch
                 * date below and the visitor's current local wall-clock time. A year
                 * is counted as a flat 365 days, so the year/day split is approximate.
                 */
                function siteTime() {
                    // Unit lengths in milliseconds.
                    var seconds = 1000;
                    var minutes = seconds * 60;
                    var hours = minutes * 60;
                    var days = hours * 24;
                    var years = days * 365;
                    var today = new Date();
                    // Launch date of the site; startMonth is a calendar month (1-12).
                    var startYear = "2020";
                    var startMonth = "6";
                    var startDate = "27";
                    var startHour = "0";
                    var startMinute = "0";
                    var startSecond = "0";
                    var todayYear = today.getFullYear();
                    var todayMonth = today.getMonth() + 1;
                    var todayDate = today.getDate();
                    var todayHour = today.getHours();
                    var todayMinute = today.getMinutes();
                    var todaySecond = today.getSeconds();
                    // Date.UTC expects a zero-based month, so both calendar months are
                    // converted. Passing them unconverted moves each date into the
                    // following month; because month lengths differ, the day count
                    // then drifts and can even step backwards at a month boundary
                    // (e.g. "January 31" overflows into early March).
                    var t1 = Date.UTC(startYear, startMonth - 1, startDate, startHour, startMinute, startSecond);
                    var t2 = Date.UTC(todayYear, todayMonth - 1, todayDate, todayHour, todayMinute, todaySecond);
                    var diff = t2 - t1;
                    // Break the difference down into years, then the remaining days,
                    // hours, minutes and seconds.
                    var diffYears = Math.floor(diff / years);
                    var diffDays = Math.floor((diff / days) - diffYears * 365);
                    var diffHours = Math.floor((diff - (diffYears * 365 + diffDays) * days) / hours);
                    var diffMinutes = Math.floor((diff - (diffYears * 365 + diffDays) * days - diffHours * hours) /
                        minutes);
                    var diffSeconds = Math.floor((diff - (diffYears * 365 + diffDays) * days - diffHours * hours -
                        diffMinutes * minutes) / seconds);
                    // Both values are plain text, so textContent is used instead of innerHTML.
                    if (startYear == todayYear) {
                        // Still in the launch year: single year, no "years" component.
                        document.getElementById("year").textContent = todayYear;
                        document.getElementById("sitetime").textContent = "本站已安全运行 " + diffDays + " 天 " + diffHours +
                            " 小时 " + diffMinutes + " 分钟 " + diffSeconds + " 秒";
                    } else {
                        document.getElementById("year").textContent = startYear + " - " + todayYear;
                        document.getElementById("sitetime").textContent = "本站已安全运行 " + diffYears + " 年 " + diffDays +
                            " 天 " + diffHours + " 小时 " + diffMinutes + " 分钟 " + diffSeconds + " 秒";
                    }
                }
                // Render once right away so the placeholder text does not linger
                // for a second, then refresh every second.
                siteTime();
                setInterval(siteTime, 1000);
            </script>
            
            <br>
            
        </div>
        <!-- Social links. Each link shows only an icon, so it carries an aria-label
             as its accessible name and the icon itself is hidden from assistive technology. -->
        <div class="col s12 m4 l4 social-link social-statis">
    <a href="https://github.com/DongZhouGu" class="tooltipped" target="_blank" rel="noopener" data-tooltip="访问我的GitHub" data-position="top" data-delay="50" aria-label="访问我的GitHub">
        <i class="fab fa-github" aria-hidden="true"></i>
    </a>

    <a href="mailto:gdz678@163.com" class="tooltipped" target="_blank" data-tooltip="邮件联系我" data-position="top" data-delay="50" aria-label="邮件联系我">
        <i class="fas fa-envelope-open" aria-hidden="true"></i>
    </a>

    <a href="tencent://AddContact/?fromId=50&amp;fromSubId=1&amp;subcmd=all&amp;uin=1596586942" class="tooltipped" target="_blank" data-tooltip="QQ联系我: 1596586942" data-position="top" data-delay="50" aria-label="QQ联系我: 1596586942">
        <i class="fab fa-qq" aria-hidden="true"></i>
    </a>

    <a href="/atom.xml" class="tooltipped" target="_blank" data-tooltip="RSS 订阅" data-position="top" data-delay="50" aria-label="RSS 订阅">
        <i class="fas fa-rss" aria-hidden="true"></i>
    </a>

</div>
    </div>
</footer>

<div class="progress-bar"></div>


<!-- Search modal overlay. The input has no visible label, so aria-label supplies
     its accessible name (a placeholder alone is not a label). -->
<div id="searchModal" class="modal">
    <div class="modal-content">
        <div class="search-header">
            <span class="title"><i class="fas fa-search" aria-hidden="true"></i>&nbsp;&nbsp;搜索</span>
            <input type="search" id="searchInput" name="s" placeholder="请输入搜索的关键字"
                   class="search-input" aria-label="请输入搜索的关键字">
        </div>
        <div id="searchResult"></div>
    </div>
</div>

<script src="/js/search.js"></script>
<script type="text/javascript">
// Once the DOM is ready, call searchFunc (presumably provided by /js/search.js)
// with the search data URL and the ids of the input and result elements.
$(document).ready(function () {
    const searchDataUrl = "/search.xml";
    const inputId = 'searchInput';
    const resultId = 'searchResult';
    searchFunc(searchDataUrl, inputId, resultId);
});
</script>

<!-- Back-to-top button. The link shows only an icon, so aria-label gives it an
     accessible name and the icon is hidden from assistive technology. -->
<div id="backTop" class="top-scroll">
    <a class="btn-floating btn-large waves-effect waves-light" href="#!" aria-label="回到顶部">
        <i class="fas fa-arrow-up" aria-hidden="true"></i>
    </a>
</div>


<!-- Library scripts, loaded synchronously in this order; /js/matery.js comes last.
     The stray leading space inside each src value has been removed. -->
<script src="https://cdn.jsdelivr.net/gh/DongZhouGu/DongZhouGu.github.io/libs/materialize/materialize.min.js"></script>
<script src="https://cdn.jsdelivr.net/gh/DongZhouGu/DongZhouGu.github.io/libs/masonry/masonry.pkgd.min.js"></script>
<script src="https://cdn.jsdelivr.net/gh/DongZhouGu/DongZhouGu.github.io/libs/aos/aos.js"></script>
<script src="https://cdn.jsdelivr.net/gh/DongZhouGu/DongZhouGu.github.io/libs/scrollprogress/scrollProgress.min.js"></script>
<script src="https://cdn.jsdelivr.net/gh/DongZhouGu/DongZhouGu.github.io/libs/lightGallery/js/lightgallery-all.min.js"></script>
<script src="/js/matery.js"></script>

<!-- Baidu Analytics -->

<!-- Baidu Push -->

<script>
    // Inject the Baidu link-submit script, choosing the endpoint that matches
    // the protocol the page was served over.
    (function () {
        var pushScript = document.createElement('script');
        var servedOverHttps = window.location.protocol.split(':')[0] === 'https';
        pushScript.src = servedOverHttps
            ? 'https://zz.bdstatic.com/linksubmit/push.js'
            : 'http://push.zhanzhang.baidu.com/push.js';
        // Insert the new element in front of the first <script> in the document.
        var firstScript = document.getElementsByTagName("script")[0];
        firstScript.parentNode.insertBefore(pushScript, firstScript);
    })();
</script>


<!-- Independent scripts, loaded asynchronously. -->
<script src="https://cdn.jsdelivr.net/gh/DongZhouGu/DongZhouGu.github.io/libs/others/clicklove.js" async></script>


<script async src="https://cdn.jsdelivr.net/gh/DongZhouGu/DongZhouGu.github.io/libs/others/busuanzi.pure.mini.js"></script>













<script src="https://cdn.jsdelivr.net/gh/DongZhouGu/DongZhouGu.github.io/libs/instantpage/instantpage.js" type="module"></script>


<script>
    // Settings shared with the lazy-load script below, which stores its
    // scan function in processImages.
    window.imageLazyLoadSetting = {
        isSPA: false,
        processImages: null,
    };
</script>
<script>
    // After the page has loaded, re-point links that wrap a lazy image: when the
    // parent <a> links to an image file or a base64 data URI, its href is
    // replaced with the image's data-original value.
    window.addEventListener("load", function () {
        var imageFilePattern = /\.(gif|jpg|jpeg|tiff|png)$/i;
        var dataUriPattern = /^data:image\/[a-z]+;base64,/;
        var lazyImages = Array.prototype.slice.call(document.querySelectorAll("img[data-original]"));

        lazyImages.forEach(function (img) {
            var wrapper = img.parentNode;
            if ("A" !== wrapper.tagName) {
                return;
            }
            if (wrapper.href.match(imageFilePattern) || wrapper.href.match(dataUriPattern)) {
                wrapper.href = img.dataset.original;
            }
        });
    });
</script>
<script>
    // Lazy loader: swaps data-original into src for images that have entered
    // the viewport, once immediately and again 500 ms after scrolling settles.
    (function (win) {
        win.imageLazyLoadSetting.processImages = processImages;

        var isSPA = win.imageLazyLoadSetting.isSPA;
        // Images that still have to be loaded.
        var pending = Array.prototype.slice.call(document.querySelectorAll("img[data-original]"));

        // Preload the real source off-screen, then apply it to the image and
        // drop the image from the pending list.
        function loadImage(img) {
            var loader = new Image();
            var realSrc = img.getAttribute("data-original");
            loader.onload = function () {
                img.src = realSrc;
                pending = pending.filter(function (candidate) {
                    return img !== candidate;
                });
            };
            loader.src = realSrc;
        }

        // Start loading every pending image whose bounding box is in view.
        function processImages() {
            if (isSPA) {
                // In SPA mode the image list is re-queried on every pass.
                pending = Array.prototype.slice.call(document.querySelectorAll("img[data-original]"));
            }
            for (var idx = 0; idx < pending.length; idx++) {
                var rect = pending[idx].getBoundingClientRect();
                if (rect.bottom >= 0 && rect.left >= 0 &&
                    rect.top <= (win.innerHeight || document.documentElement.clientHeight)) {
                    loadImage(pending[idx]);
                }
            }
        }

        processImages();

        // Debounce: the timer id is kept on the function object itself.
        win.addEventListener("scroll", function () {
            clearTimeout(processImages.tId);
            processImages.tId = setTimeout(function () {
                processImages.call(win);
            }, 500);
        });
    })(this);
</script>
</body>

</html>
